Skip to content

Commit 2e551e9

Browse files
committed
Support cirq classical control when lowering a circuit to squin (#315)
1 parent 528418e commit 2e551e9

File tree

2 files changed

+129
-4
lines changed

2 files changed

+129
-4
lines changed

src/bloqade/squin/cirq/lowering.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import cirq
66
from kirin import ir, lowering
77
from kirin.rewrite import Walk, CFGCompactify
8-
from kirin.dialects import py, ilist
8+
from kirin.dialects import py, scf, ilist
99

1010
from .. import op, noise, qubit
1111

@@ -150,10 +150,79 @@ def lower_measurement(
150150
):
151151
if len(node.qubits) == 1:
152152
qbit = self.lower_qubit_getindex(state, node.qubits[0])
153-
return state.current_frame.push(qubit.MeasureQubit(qbit))
153+
stmt = state.current_frame.push(qubit.MeasureQubit(qbit))
154+
else:
155+
qbits = self.lower_qubit_getindices(state, node.qubits)
156+
stmt = state.current_frame.push(qubit.MeasureQubitList(qbits))
154157

155-
qbits = self.lower_qubit_getindices(state, node.qubits)
156-
return state.current_frame.push(qubit.MeasureQubitList(qbits))
158+
key = node.gate.key
159+
if isinstance(key, cirq.MeasurementKey):
160+
key = key.name
161+
162+
state.current_frame.defs[key] = stmt.result
163+
return stmt
164+
165+
def visit_ClassicallyControlledOperation(
166+
self, state: lowering.State[CirqNode], node: cirq.ClassicallyControlledOperation
167+
):
168+
conditions: list[ir.SSAValue] = []
169+
for outcome in node.classical_controls:
170+
key = outcome.key
171+
if isinstance(key, cirq.MeasurementKey):
172+
key = key.name
173+
measurement_outcome = state.current_frame.defs[key]
174+
175+
if measurement_outcome.type.is_subseteq(ilist.IListType):
176+
# NOTE: there is currently no convenient ilist.any method, so we need to use foldl
177+
# with a simple function that just does an or
178+
179+
def bool_op_or(x: bool, y: bool) -> bool:
180+
return x or y
181+
182+
f_code = state.current_frame.push(
183+
lowering.Python(self.dialects).python_function(bool_op_or)
184+
)
185+
fn = ir.Method(
186+
mod=None,
187+
py_func=bool_op_or,
188+
sym_name="bool_op_or",
189+
arg_names=[],
190+
dialects=self.dialects,
191+
code=f_code,
192+
)
193+
f_const = state.current_frame.push(py.constant.Constant(fn))
194+
init_val = state.current_frame.push(py.Constant(False)).result
195+
condition = state.current_frame.push(
196+
ilist.Foldl(f_const.result, measurement_outcome, init=init_val)
197+
).result
198+
else:
199+
condition = measurement_outcome
200+
201+
conditions.append(condition)
202+
203+
if len(conditions) == 1:
204+
condition = conditions[0]
205+
else:
206+
condition = state.current_frame.push(
207+
py.boolop.And(conditions[0], conditions[1])
208+
).result
209+
for next_cond in conditions[2:]:
210+
condition = state.current_frame.push(
211+
py.boolop.And(condition, next_cond)
212+
).result
213+
214+
then_stmt = self.visit(state, node.without_classical_controls())
215+
216+
assert isinstance(
217+
then_stmt, ir.Statement
218+
), f"Expected operation of classically controlled node {node} to be lowered to a statement, got type {type(then_stmt)}. \
219+
Please report this issue!"
220+
221+
# NOTE: remove stmt from parent block
222+
then_stmt.detach()
223+
then_body = ir.Block((then_stmt,))
224+
225+
return state.current_frame.push(scf.IfElse(condition, then_body=then_body))
157226

158227
def visit_SingleQubitPauliStringGateOperation(
159228
self,

test/squin/cirq/test_cirq_to_squin.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,59 @@ def test_circuit(circuit_f, run_sim: bool = False):
159159
sim = DynamicMemorySimulator()
160160
ket = sim.state_vector(kernel=kernel)
161161
print(ket)
162+
163+
164+
def test_classical_control(run_sim: bool = False):
165+
q = cirq.LineQubit.range(2)
166+
circuit = cirq.Circuit(
167+
cirq.H(q[0]),
168+
cirq.measure(q[0]),
169+
cirq.X(q[1]).with_classical_controls("q(0)"),
170+
cirq.measure(q[1]),
171+
)
172+
173+
print(circuit)
174+
175+
if run_sim:
176+
simulator = cirq.Simulator()
177+
simulator.run(circuit, repetitions=1)
178+
179+
kernel = squin.cirq.load_circuit(circuit)
180+
kernel.print()
181+
182+
183+
def test_classical_control_register():
184+
q = cirq.LineQubit.range(2)
185+
circuit = cirq.Circuit(
186+
cirq.H(q[0]),
187+
cirq.measure(q, key="test"),
188+
cirq.X(q[1]).with_classical_controls("test"),
189+
cirq.measure(q[1]),
190+
)
191+
192+
print(circuit)
193+
194+
kernel = squin.cirq.load_circuit(circuit)
195+
kernel.print()
196+
197+
198+
def test_multiple_classical_controls(run_sim: bool = False):
199+
q = cirq.LineQubit.range(2)
200+
q2 = cirq.GridQubit(0, 1)
201+
circuit = cirq.Circuit(
202+
cirq.H(q[0]),
203+
cirq.H(q2),
204+
cirq.measure(q, key="test"),
205+
cirq.measure(q2),
206+
cirq.X(q[1]).with_classical_controls("test", "q(0, 1)"),
207+
cirq.measure(q[1]),
208+
)
209+
210+
print(circuit)
211+
212+
if run_sim:
213+
sim = cirq.Simulator()
214+
sim.run(circuit)
215+
216+
kernel = squin.cirq.load_circuit(circuit)
217+
kernel.print()

0 commit comments

Comments
 (0)