Skip to content

Commit 34ca7e5

Browse files
committed
Move circuit field from EmitCirqFrame to EmitCirq
1 parent c2ab442 commit 34ca7e5

File tree

6 files changed

+69
-31
lines changed

6 files changed

+69
-31
lines changed

src/bloqade/cirq_utils/emit/base.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from kirin import ir, types, interp
77
from kirin.emit import EmitABC, EmitFrame
88
from kirin.interp import MethodTable, impl
9-
from kirin.dialects import py, func
9+
from kirin.dialects import py, func, ilist
1010
from typing_extensions import Self
1111

1212
from bloqade.squin import kernel
@@ -149,17 +149,23 @@ def main():
149149
new_func = func.Function(
150150
sym_name=sym_name, body=callable_region, signature=new_signature
151151
)
152-
mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func)
152+
# mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func)
153+
mt_ = ir.Method(
154+
dialects=mt.dialects,
155+
code=new_func,
156+
sym_name=sym_name,
157+
)
153158

154159
AggressiveUnroll(mt_.dialects).fixpoint(mt_)
155-
return emitter.run(mt_, args=())
160+
emitter.initialize()
161+
emitter.run(mt_)
162+
return emitter.circuit
156163

157164

158165
@dataclass
159166
class EmitCirqFrame(EmitFrame):
160167
qubit_index: int = 0
161168
qubits: Sequence[cirq.Qid] | None = None
162-
circuit: cirq.Circuit = field(default_factory=cirq.Circuit)
163169

164170

165171
def _default_kernel():
@@ -172,19 +178,20 @@ class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
172178
dialects: ir.DialectGroup = field(default_factory=_default_kernel)
173179
void = cirq.Circuit()
174180
qubits: Sequence[cirq.Qid] | None = None
181+
circuit: cirq.Circuit = field(default_factory=cirq.Circuit)
175182

176183
def initialize(self) -> Self:
177184
return super().initialize()
178185

179186
def initialize_frame(
180-
self, code: ir.Statement, *, has_parent_access: bool = False
187+
self, node: ir.Statement, *, has_parent_access: bool = False
181188
) -> EmitCirqFrame:
182189
return EmitCirqFrame(
183-
code, has_parent_access=has_parent_access, qubits=self.qubits
190+
node, has_parent_access=has_parent_access, qubits=self.qubits
184191
)
185192

186193
def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
187-
return self.run_callable(method.code, args)
194+
return self.call(method, *args)
188195

189196
def run_callable_region(
190197
self,
@@ -198,7 +205,7 @@ def run_callable_region(
198205
# NOTE: skip self arg
199206
frame.set_values(block_args[1:], args)
200207

201-
results = self.eval_stmt(frame, code)
208+
results = self.frame_eval(frame, code)
202209
if isinstance(results, tuple):
203210
if len(results) == 0:
204211
return self.void
@@ -208,20 +215,32 @@ def run_callable_region(
208215

209216
def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
210217
for stmt in block.stmts:
211-
result = self.eval_stmt(frame, stmt)
218+
result = self.frame_eval(frame, stmt)
212219
if isinstance(result, tuple):
213220
frame.set_values(stmt.results, result)
214221

215-
return frame.circuit
222+
return self.circuit
223+
224+
def reset(self):
225+
pass
216226

217227

218228
@func.dialect.register(key="emit.cirq")
219229
class __FuncEmit(MethodTable):
220230

221231
@impl(func.Function)
222232
def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
223-
emit.run_ssacfg_region(frame, stmt.body, ())
224-
return (frame.circuit,)
233+
for block in stmt.body.blocks:
234+
frame.current_block = block
235+
for s in block.stmts:
236+
frame.current_stmt = s
237+
stmt_results = emit.frame_eval(frame, s)
238+
if isinstance(stmt_results, tuple):
239+
if len(stmt_results) != 0:
240+
frame.set_values(s.results, stmt_results)
241+
continue
242+
243+
return (emit.circuit,)
225244

226245
@impl(func.Invoke)
227246
def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
@@ -235,6 +254,12 @@ def return_(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Return):
235254
# NOTE: should only be hit if ignore_returns == True
236255
return ()
237256

257+
@impl(func.ConstantNone)
258+
def emit_constant_none(
259+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.ConstantNone
260+
):
261+
return ()
262+
238263

239264
@py.indexing.dialect.register(key="emit.cirq")
240265
class __Concrete(interp.MethodTable):
@@ -243,3 +268,19 @@ class __Concrete(interp.MethodTable):
243268
def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem):
244269
# NOTE: no support for indexing into single statements in cirq
245270
return ()
271+
272+
@interp.impl(py.Constant)
273+
def emit_constant(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: py.Constant):
274+
return (stmt.value.data,) # pyright: ignore[reportAttributeAccessIssue]
275+
276+
277+
@ilist.dialect.register(key="emit.cirq")
278+
class __IList(interp.MethodTable):
279+
@interp.impl(ilist.New)
280+
def new_ilist(
281+
self,
282+
emit: EmitCirq,
283+
frame: interp.Frame,
284+
stmt: ilist.New,
285+
):
286+
return (ilist.IList(data=frame.get_values(stmt.values)),)

src/bloqade/cirq_utils/emit/gate.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def hermitian(
2020
):
2121
qubits = frame.get(stmt.qubits)
2222
cirq_op = getattr(cirq, stmt.name.upper())
23-
frame.circuit.append(cirq_op.on_each(qubits))
23+
emit.circuit.append(cirq_op.on_each(qubits))
2424
return ()
2525

2626
@impl(gate.stmts.S)
@@ -36,7 +36,7 @@ def unitary(
3636
if stmt.adjoint:
3737
cirq_op = cirq_op ** (-1)
3838

39-
frame.circuit.append(cirq_op.on_each(qubits))
39+
emit.circuit.append(cirq_op.on_each(qubits))
4040
return ()
4141

4242
@impl(gate.stmts.SqrtX)
@@ -58,7 +58,7 @@ def sqrt(
5858
else:
5959
cirq_op = cirq.YPowGate(exponent=exponent)
6060

61-
frame.circuit.append(cirq_op.on_each(qubits))
61+
emit.circuit.append(cirq_op.on_each(qubits))
6262
return ()
6363

6464
@impl(gate.stmts.CX)
@@ -71,7 +71,7 @@ def control(
7171
targets = frame.get(stmt.targets)
7272
cirq_op = getattr(cirq, stmt.name.upper())
7373
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
74-
frame.circuit.append(cirq_op.on_each(cirq_qubits))
74+
emit.circuit.append(cirq_op.on_each(cirq_qubits))
7575
return ()
7676

7777
@impl(gate.stmts.Rx)
@@ -84,7 +84,7 @@ def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: gate.stmts.RotationGat
8484
angle = turns * 2 * math.pi
8585
cirq_op = getattr(cirq, stmt.name.title())(rads=angle)
8686

87-
frame.circuit.append(cirq_op.on_each(qubits))
87+
emit.circuit.append(cirq_op.on_each(qubits))
8888
return ()
8989

9090
@impl(gate.stmts.U3)
@@ -95,10 +95,10 @@ def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: gate.stmts.U3):
9595
phi = frame.get(stmt.phi) * 2 * math.pi
9696
lam = frame.get(stmt.lam) * 2 * math.pi
9797

98-
frame.circuit.append(cirq.Rz(rads=lam).on_each(*qubits))
98+
emit.circuit.append(cirq.Rz(rads=lam).on_each(*qubits))
9999

100-
frame.circuit.append(cirq.Ry(rads=theta).on_each(*qubits))
100+
emit.circuit.append(cirq.Ry(rads=theta).on_each(*qubits))
101101

102-
frame.circuit.append(cirq.Rz(rads=phi).on_each(*qubits))
102+
emit.circuit.append(cirq.Rz(rads=phi).on_each(*qubits))
103103

104104
return ()

src/bloqade/cirq_utils/emit/noise.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def depolarize(
3434
p = frame.get(stmt.p)
3535
qubits = frame.get(stmt.qubits)
3636
cirfq_op = cirq.depolarize(p, n_qubits=1).on_each(qubits)
37-
frame.circuit.append(cirfq_op)
37+
interp.circuit.append(cirfq_op)
3838
return ()
3939

4040
@impl(noise.stmts.Depolarize2)
@@ -46,7 +46,7 @@ def depolarize2(
4646
targets = frame.get(stmt.targets)
4747
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
4848
cirq_op = cirq.depolarize(p, n_qubits=2).on_each(cirq_qubits)
49-
frame.circuit.append(cirq_op)
49+
interp.circuit.append(cirq_op)
5050
return ()
5151

5252
@impl(noise.stmts.SingleQubitPauliChannel)
@@ -62,7 +62,7 @@ def single_qubit_pauli_channel(
6262
qubits = frame.get(stmt.qubits)
6363

6464
cirq_op = cirq.asymmetric_depolarize(px, py, pz).on_each(qubits)
65-
frame.circuit.append(cirq_op)
65+
interp.circuit.append(cirq_op)
6666

6767
return ()
6868

@@ -85,6 +85,6 @@ def two_qubit_pauli_channel(
8585
cirq_op = cirq.asymmetric_depolarize(
8686
error_probabilities=error_probabilities
8787
).on_each(cirq_qubits)
88-
frame.circuit.append(cirq_op)
88+
interp.circuit.append(cirq_op)
8989

9090
return ()

src/bloqade/cirq_utils/emit/qubit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ def measure_qubit_list(
2323
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Measure
2424
):
2525
qbits = frame.get(stmt.qubits)
26-
frame.circuit.append(cirq.measure(qbits))
26+
emit.circuit.append(cirq.measure(qbits))
2727
return (emit.void,)
2828

2929
@impl(qubit.Reset)
3030
def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Reset):
3131
qubits = frame.get(stmt.qubits)
32-
frame.circuit.append(
32+
emit.circuit.append(
3333
cirq.ResetChannel().on_each(*qubits),
3434
)
3535
return ()

src/bloqade/cirq_utils/lowering.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,7 @@ def run(
260260
# NOTE: create a new register of appropriate size
261261
n_qubits = len(self.qreg_index)
262262
n = frame.push(py.Constant(n_qubits))
263-
self.qreg = frame.push(
264-
func.Invoke((n.result,), callee=qalloc, kwargs=())
265-
).result
263+
self.qreg = frame.push(func.Invoke((n.result,), callee=qalloc)).result
266264

267265
self.visit(state, stmt)
268266

src/bloqade/qasm2/emit/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ def initialize_frame(
4848
def run_method(
4949
self, method: ir.Method, args: tuple[ast.Node | None, ...]
5050
) -> tuple[EmitQASM2Frame[StmtType], ast.Node | None]:
51-
sym_name = method.sym_name if method.sym_name is not None else ""
52-
return self.call(method, ast.Name(sym_name), *args)
51+
return self.call(method, *args)
5352

5453
def emit_block(self, frame: EmitQASM2Frame, block: ir.Block) -> ast.Node | None:
5554
for stmt in block.stmts:

0 commit comments

Comments
 (0)