Skip to content

Commit 55686c3

Browse files
Code Generation for QASM2 (#555)
Refactor QASM2's code generation to use the new Emit APIs. --------- Co-authored-by: David Plankensteiner <[email protected]>
1 parent 04e753b commit 55686c3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+685
-357
lines changed

src/bloqade/analysis/fidelity/analysis.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,15 @@ def run_analysis(
8383
self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
8484
) -> tuple[ForwardFrame, Any]:
8585
self._run_address_analysis(method, no_raise=no_raise)
86-
return super().run_analysis(method, args, no_raise=no_raise)
86+
return super().run(method)
8787

8888
def _run_address_analysis(self, method: ir.Method, no_raise: bool):
8989
addr_analysis = AddressAnalysis(self.dialects)
90-
addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise)
90+
addr_frame, _ = addr_analysis.run(method=method)
9191
self.addr_frame = addr_frame
9292

9393
# NOTE: make sure we have as many probabilities as we have addresses
9494
self.atom_survival_probability = [1.0] * addr_analysis.qubit_count
95+
96+
def method_self(self, method: ir.Method) -> EmptyLattice:
97+
return self.lattice.bottom()

src/bloqade/analysis/measure_id/analysis.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,6 @@ def get_const_value(
5353
return hint.data
5454

5555
return None
56+
57+
def method_self(self, method: ir.Method) -> MeasureId:
58+
return self.lattice.bottom()

src/bloqade/cirq_utils/emit/base.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
import cirq
66
from kirin import ir, types, interp
7-
from kirin.emit import EmitABC, EmitError, EmitFrame
7+
from kirin.emit import EmitABC, EmitFrame
88
from kirin.interp import MethodTable, impl
9-
from kirin.dialects import py, func
9+
from kirin.dialects import py, func, ilist
1010
from typing_extensions import Self
1111

1212
from bloqade.squin import kernel
@@ -102,7 +102,7 @@ def main():
102102
and isinstance(mt.code, func.Function)
103103
and not mt.code.signature.output.is_subseteq(types.NoneType)
104104
):
105-
raise EmitError(
105+
raise interp.exceptions.InterpreterError(
106106
"The method you are trying to convert to a circuit has a return value, but returning from a circuit is not supported."
107107
" Set `ignore_returns = True` in order to simply ignore the return values and emit a circuit."
108108
)
@@ -116,12 +116,14 @@ def main():
116116

117117
symbol_op_trait = mt.code.get_trait(ir.SymbolOpInterface)
118118
if (symbol_op_trait := mt.code.get_trait(ir.SymbolOpInterface)) is None:
119-
raise EmitError("The method is not a symbol, cannot emit circuit!")
119+
raise interp.exceptions.InterpreterError(
120+
"The method is not a symbol, cannot emit circuit!"
121+
)
120122

121123
sym_name = symbol_op_trait.get_sym_name(mt.code).unwrap()
122124

123125
if (signature_trait := mt.code.get_trait(ir.HasSignature)) is None:
124-
raise EmitError(
126+
raise interp.exceptions.InterpreterError(
125127
f"The method {sym_name} does not have a signature, cannot emit circuit!"
126128
)
127129

@@ -135,7 +137,7 @@ def main():
135137

136138
assert first_stmt is not None, "Method has no statements!"
137139
if len(args_ssa) - 1 != len(args):
138-
raise EmitError(
140+
raise interp.exceptions.InterpreterError(
139141
f"The method {sym_name} takes {len(args_ssa) - 1} arguments, but you passed in {len(args)} via the `args` keyword!"
140142
)
141143

@@ -147,17 +149,22 @@ def main():
147149
new_func = func.Function(
148150
sym_name=sym_name, body=callable_region, signature=new_signature
149151
)
150-
mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func)
152+
mt_ = ir.Method(
153+
dialects=mt.dialects,
154+
code=new_func,
155+
sym_name=sym_name,
156+
)
151157

152158
AggressiveUnroll(mt_.dialects).fixpoint(mt_)
153-
return emitter.run(mt_, args=())
159+
emitter.initialize()
160+
emitter.run(mt_)
161+
return emitter.circuit
154162

155163

156164
@dataclass
157165
class EmitCirqFrame(EmitFrame):
158166
qubit_index: int = 0
159167
qubits: Sequence[cirq.Qid] | None = None
160-
circuit: cirq.Circuit = field(default_factory=cirq.Circuit)
161168

162169

163170
def _default_kernel():
@@ -166,23 +173,24 @@ def _default_kernel():
166173

167174
@dataclass
168175
class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
169-
keys = ["emit.cirq", "main"]
176+
keys = ("emit.cirq", "emit.main")
170177
dialects: ir.DialectGroup = field(default_factory=_default_kernel)
171178
void = cirq.Circuit()
172179
qubits: Sequence[cirq.Qid] | None = None
180+
circuit: cirq.Circuit = field(default_factory=cirq.Circuit)
173181

174182
def initialize(self) -> Self:
175183
return super().initialize()
176184

177185
def initialize_frame(
178-
self, code: ir.Statement, *, has_parent_access: bool = False
186+
self, node: ir.Statement, *, has_parent_access: bool = False
179187
) -> EmitCirqFrame:
180188
return EmitCirqFrame(
181-
code, has_parent_access=has_parent_access, qubits=self.qubits
189+
node, has_parent_access=has_parent_access, qubits=self.qubits
182190
)
183191

184192
def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
185-
return self.run_callable(method.code, args)
193+
return self.call(method, *args)
186194

187195
def run_callable_region(
188196
self,
@@ -196,7 +204,7 @@ def run_callable_region(
196204
# NOTE: skip self arg
197205
frame.set_values(block_args[1:], args)
198206

199-
results = self.eval_stmt(frame, code)
207+
results = self.frame_eval(frame, code)
200208
if isinstance(results, tuple):
201209
if len(results) == 0:
202210
return self.void
@@ -206,33 +214,43 @@ def run_callable_region(
206214

207215
def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
208216
for stmt in block.stmts:
209-
result = self.eval_stmt(frame, stmt)
217+
result = self.frame_eval(frame, stmt)
210218
if isinstance(result, tuple):
211219
frame.set_values(stmt.results, result)
212220

213-
return frame.circuit
221+
return self.circuit
222+
223+
def reset(self):
224+
pass
225+
226+
def eval_fallback(self, frame: EmitCirqFrame, node: ir.Statement) -> tuple:
227+
return tuple(None for _ in range(len(node.results)))
214228

215229

216230
@func.dialect.register(key="emit.cirq")
217231
class __FuncEmit(MethodTable):
218232

219233
@impl(func.Function)
220234
def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
221-
emit.run_ssacfg_region(frame, stmt.body, ())
222-
return (frame.circuit,)
235+
for block in stmt.body.blocks:
236+
frame.current_block = block
237+
for s in block.stmts:
238+
frame.current_stmt = s
239+
stmt_results = emit.frame_eval(frame, s)
240+
if isinstance(stmt_results, tuple):
241+
if len(stmt_results) != 0:
242+
frame.set_values(s.results, stmt_results)
243+
continue
244+
245+
return (emit.circuit,)
223246

224247
@impl(func.Invoke)
225248
def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
226-
raise EmitError(
249+
raise interp.exceptions.InterpreterError(
227250
"Function invokes should need to be inlined! "
228251
"If you called the emit_circuit method, that should have happened, please report this issue."
229252
)
230253

231-
@impl(func.Return)
232-
def return_(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Return):
233-
# NOTE: should only be hit if ignore_returns == True
234-
return ()
235-
236254

237255
@py.indexing.dialect.register(key="emit.cirq")
238256
class __Concrete(interp.MethodTable):
@@ -241,3 +259,19 @@ class __Concrete(interp.MethodTable):
241259
def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem):
242260
# NOTE: no support for indexing into single statements in cirq
243261
return ()
262+
263+
@interp.impl(py.Constant)
264+
def emit_constant(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: py.Constant):
265+
return (stmt.value.data,) # pyright: ignore[reportAttributeAccessIssue]
266+
267+
268+
@ilist.dialect.register(key="emit.cirq")
269+
class __IList(interp.MethodTable):
270+
@interp.impl(ilist.New)
271+
def new_ilist(
272+
self,
273+
emit: EmitCirq,
274+
frame: interp.Frame,
275+
stmt: ilist.New,
276+
):
277+
return (ilist.IList(data=frame.get_values(stmt.values)),)

src/bloqade/cirq_utils/emit/gate.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def hermitian(
2020
):
2121
qubits = frame.get(stmt.qubits)
2222
cirq_op = getattr(cirq, stmt.name.upper())
23-
frame.circuit.append(cirq_op.on_each(qubits))
23+
emit.circuit.append(cirq_op.on_each(qubits))
2424
return ()
2525

2626
@impl(gate.stmts.S)
@@ -36,7 +36,7 @@ def unitary(
3636
if stmt.adjoint:
3737
cirq_op = cirq_op ** (-1)
3838

39-
frame.circuit.append(cirq_op.on_each(qubits))
39+
emit.circuit.append(cirq_op.on_each(qubits))
4040
return ()
4141

4242
@impl(gate.stmts.SqrtX)
@@ -58,7 +58,7 @@ def sqrt(
5858
else:
5959
cirq_op = cirq.YPowGate(exponent=exponent)
6060

61-
frame.circuit.append(cirq_op.on_each(qubits))
61+
emit.circuit.append(cirq_op.on_each(qubits))
6262
return ()
6363

6464
@impl(gate.stmts.CX)
@@ -71,7 +71,7 @@ def control(
7171
targets = frame.get(stmt.targets)
7272
cirq_op = getattr(cirq, stmt.name.upper())
7373
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
74-
frame.circuit.append(cirq_op.on_each(cirq_qubits))
74+
emit.circuit.append(cirq_op.on_each(cirq_qubits))
7575
return ()
7676

7777
@impl(gate.stmts.Rx)
@@ -84,7 +84,7 @@ def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: gate.stmts.RotationGat
8484
angle = turns * 2 * math.pi
8585
cirq_op = getattr(cirq, stmt.name.title())(rads=angle)
8686

87-
frame.circuit.append(cirq_op.on_each(qubits))
87+
emit.circuit.append(cirq_op.on_each(qubits))
8888
return ()
8989

9090
@impl(gate.stmts.U3)
@@ -95,10 +95,10 @@ def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: gate.stmts.U3):
9595
phi = frame.get(stmt.phi) * 2 * math.pi
9696
lam = frame.get(stmt.lam) * 2 * math.pi
9797

98-
frame.circuit.append(cirq.Rz(rads=lam).on_each(*qubits))
98+
emit.circuit.append(cirq.Rz(rads=lam).on_each(*qubits))
9999

100-
frame.circuit.append(cirq.Ry(rads=theta).on_each(*qubits))
100+
emit.circuit.append(cirq.Ry(rads=theta).on_each(*qubits))
101101

102-
frame.circuit.append(cirq.Rz(rads=phi).on_each(*qubits))
102+
emit.circuit.append(cirq.Rz(rads=phi).on_each(*qubits))
103103

104104
return ()

src/bloqade/cirq_utils/emit/noise.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def depolarize(
3434
p = frame.get(stmt.p)
3535
qubits = frame.get(stmt.qubits)
3636
cirfq_op = cirq.depolarize(p, n_qubits=1).on_each(qubits)
37-
frame.circuit.append(cirfq_op)
37+
interp.circuit.append(cirfq_op)
3838
return ()
3939

4040
@impl(noise.stmts.Depolarize2)
@@ -46,7 +46,7 @@ def depolarize2(
4646
targets = frame.get(stmt.targets)
4747
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
4848
cirq_op = cirq.depolarize(p, n_qubits=2).on_each(cirq_qubits)
49-
frame.circuit.append(cirq_op)
49+
interp.circuit.append(cirq_op)
5050
return ()
5151

5252
@impl(noise.stmts.SingleQubitPauliChannel)
@@ -62,7 +62,7 @@ def single_qubit_pauli_channel(
6262
qubits = frame.get(stmt.qubits)
6363

6464
cirq_op = cirq.asymmetric_depolarize(px, py, pz).on_each(qubits)
65-
frame.circuit.append(cirq_op)
65+
interp.circuit.append(cirq_op)
6666

6767
return ()
6868

@@ -85,6 +85,6 @@ def two_qubit_pauli_channel(
8585
cirq_op = cirq.asymmetric_depolarize(
8686
error_probabilities=error_probabilities
8787
).on_each(cirq_qubits)
88-
frame.circuit.append(cirq_op)
88+
interp.circuit.append(cirq_op)
8989

9090
return ()

src/bloqade/cirq_utils/emit/qubit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ def measure_qubit_list(
2323
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Measure
2424
):
2525
qbits = frame.get(stmt.qubits)
26-
frame.circuit.append(cirq.measure(qbits))
26+
emit.circuit.append(cirq.measure(qbits))
2727
return (emit.void,)
2828

2929
@impl(qubit.Reset)
3030
def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Reset):
3131
qubits = frame.get(stmt.qubits)
32-
frame.circuit.append(
32+
emit.circuit.append(
3333
cirq.ResetChannel().on_each(*qubits),
3434
)
3535
return ()

src/bloqade/cirq_utils/lowering.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,7 @@ def run(
260260
# NOTE: create a new register of appropriate size
261261
n_qubits = len(self.qreg_index)
262262
n = frame.push(py.Constant(n_qubits))
263-
self.qreg = frame.push(
264-
func.Invoke((n.result,), callee=qalloc, kwargs=())
265-
).result
263+
self.qreg = frame.push(func.Invoke((n.result,), callee=qalloc)).result
266264

267265
self.visit(state, stmt)
268266

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .squin2native import (
22
GateRule as GateRule,
33
SquinToNative as SquinToNative,
4-
SquinToNativePass as SquinToNativePass,
54
)

src/bloqade/pyqrack/device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def task(
353353
kwargs = {}
354354

355355
address_analysis = AddressAnalysis(dialects=kernel.dialects)
356-
frame, _ = address_analysis.run_analysis(kernel)
356+
frame, _ = address_analysis.run(kernel)
357357
if self.min_qubits == 0 and any(
358358
isinstance(a, (UnknownQubit, UnknownReg)) for a in frame.entries.values()
359359
):

src/bloqade/pyqrack/target.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _get_interp(self, mt: ir.Method[Params, RetType]):
5151
return PyQrackInterpreter(mt.dialects, memory=DynamicMemory(options))
5252
else:
5353
address_analysis = AddressAnalysis(mt.dialects)
54-
frame, _ = address_analysis.run_analysis(mt)
54+
frame, _ = address_analysis.run(mt)
5555
if self.min_qubits == 0 and any(
5656
isinstance(a, UnknownQubit) for a in frame.entries.values()
5757
):

0 commit comments

Comments
 (0)