From 19a907e5a1b5e9cc527b05675b6e9cb22630212c Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 1 Oct 2025 09:23:47 +0200 Subject: [PATCH 01/11] Refactor cirq lowering to clifford --- src/bloqade/cirq_utils/lowering.py | 459 ++++++++++++++------------ test/cirq_utils/test_cirq_to_squin.py | 8 +- 2 files changed, 247 insertions(+), 220 deletions(-) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 76a706f5..670270cc 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -1,4 +1,3 @@ -import math from typing import Any from dataclasses import field, dataclass @@ -7,7 +6,7 @@ from kirin.rewrite import Walk, CFGCompactify from kirin.dialects import py, scf, func, ilist -from bloqade.squin import op, noise, qubit, kernel +from bloqade.squin import qubit, kernel, clifford def load_circuit( @@ -157,11 +156,16 @@ def main(): | cirq.PhasedXPowGate | cirq.PhasedXZGate | cirq.CSwapGate + | cirq.XXPowGate + | cirq.YYPowGate + | cirq.ZZPowGate + | cirq.CCXPowGate + | cirq.CCZPowGate ) @dataclass -class Squin(lowering.LoweringABC[CirqNode]): +class Squin(lowering.LoweringABC[cirq.Circuit]): """Lower a cirq.Circuit object to a squin kernel""" circuit: cirq.Circuit @@ -174,21 +178,22 @@ def __post_init__(self): qbits = sorted(self.circuit.all_qubits()) self.qreg_index = {qid: idx for (idx, qid) in enumerate(qbits)} - def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid): + def lower_qubit_getindex(self, state: lowering.State[cirq.Circuit], qid: cirq.Qid): index = self.qreg_index[qid] index_ssa = state.current_frame.push(py.Constant(index)).result qbit_getitem = state.current_frame.push(py.GetItem(self.qreg, index_ssa)) return qbit_getitem.result def lower_qubit_getindices( - self, state: lowering.State[CirqNode], qids: list[cirq.Qid] + self, state: lowering.State[cirq.Circuit], qids: tuple[cirq.Qid, ...] ): qbits_getitem = [self.lower_qubit_getindex(state, qid) for qid in qids] - return tuple(qbits_getitem) + qbits = state.current_frame.push(ilist.New(values=qbits_getitem)) + return qbits.result def run( self, - stmt: CirqNode, + stmt: cirq.Circuit, *, source: str | None = None, globals: dict[str, Any] | None = None, @@ -232,83 +237,65 @@ def run( return region - def visit(self, state: lowering.State[CirqNode], node: CirqNode) -> lowering.Result: + def visit( + self, state: lowering.State[cirq.Circuit], node: CirqNode + ) -> lowering.Result: name = node.__class__.__name__ return getattr(self, f"visit_{name}", self.generic_visit)(state, node) - def generic_visit(self, state: lowering.State[CirqNode], node: CirqNode): + def generic_visit(self, state: lowering.State[cirq.Circuit], node: CirqNode): if isinstance(node, CirqNode): raise lowering.BuildError( f"Cannot lower {node.__class__.__name__} node: {node}" ) - raise lowering.BuildError( - f"Unexpected `{node.__class__.__name__}` node: {repr(node)} is not an AST node" - ) + raise lowering.BuildError(f"Cannot lower {node}") + + # return self.visit_Operation(state, node) - def lower_literal(self, state: lowering.State[CirqNode], value) -> ir.SSAValue: + def lower_literal(self, state: lowering.State[cirq.Circuit], value) -> ir.SSAValue: raise lowering.BuildError("Literals not supported in cirq circuit") def lower_global( - self, state: lowering.State[CirqNode], node: CirqNode + self, state: lowering.State[cirq.Circuit], node: CirqNode ) -> lowering.LoweringABC.Result: raise lowering.BuildError("Literals not supported in cirq circuit") def visit_Circuit( - self, state: lowering.State[CirqNode], node: cirq.Circuit + self, state: lowering.State[cirq.Circuit], node: cirq.Circuit ) -> lowering.Result: for moment in node: - state.lower(moment) + self.visit_Moment(state, moment) def visit_Moment( - self, state: lowering.State[CirqNode], node: cirq.Moment + self, state: lowering.State[cirq.Circuit], node: cirq.Moment ) -> lowering.Result: for op_ in node.operations: - state.lower(op_) + self.visit(state, op_) def visit_GateOperation( - self, state: lowering.State[CirqNode], node: cirq.GateOperation + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation ): - if isinstance(node.gate, cirq.MeasurementGate): - # NOTE: special dispatch here, since measurement is a gate + a qubit in cirq, - # but a single statement in squin - return self.lower_measurement(state, node) - if isinstance(node.gate, DecomposeNode): # NOTE: easier to decompose these, but for that we need the qubits too, # so we need to do this within this method for subnode in cirq.decompose_once(node): - state.lower(subnode) + self.visit(state, subnode) return - op_ = state.lower(node.gate).expect_one() - qbits = self.lower_qubit_getindices(state, node.qubits) - return state.current_frame.push(qubit.Apply(operator=op_, qubits=qbits)) + # NOTE: just forward to the appropriate method by getting the name + name = node.gate.__class__.__name__ + return getattr(self, f"visit_{name}", self.generic_visit)(state, node) def visit_TaggedOperation( - self, state: lowering.State[CirqNode], node: cirq.TaggedOperation + self, state: lowering.State[cirq.Circuit], node: cirq.TaggedOperation ): - state.lower(node.untagged) - - def lower_measurement( - self, state: lowering.State[CirqNode], node: cirq.GateOperation - ): - if len(node.qubits) == 1: - qbit = self.lower_qubit_getindex(state, node.qubits[0]) - stmt = state.current_frame.push(qubit.MeasureQubit(qbit)) - else: - qbits = self.lower_qubit_getindices(state, node.qubits) - qbits_list = state.current_frame.push(ilist.New(values=qbits)) - stmt = state.current_frame.push(qubit.MeasureQubitList(qbits_list.result)) - - key = node.gate.key - if isinstance(key, cirq.MeasurementKey): - key = key.name - - state.current_frame.defs[key] = stmt.result - return stmt + return self.visit(state, node.untagged) + # state.lower(node.untagged) def visit_ClassicallyControlledOperation( - self, state: lowering.State[CirqNode], node: cirq.ClassicallyControlledOperation + self, + state: lowering.State[cirq.Circuit], + node: cirq.ClassicallyControlledOperation, ): conditions: list[ir.SSAValue] = [] for outcome in node.classical_controls: @@ -369,211 +356,251 @@ def bool_op_or(x: bool, y: bool) -> bool: return state.current_frame.push(scf.IfElse(condition, then_body=then_body)) + def visit_MeasurementGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation + ): + cirq_qubits = node.qubits + if len(cirq_qubits) == 1: + qbit = self.lower_qubit_getindex(state, node.qubits[0]) + stmt = state.current_frame.push(qubit.MeasureQubit(qbit)) + else: + qubits = self.lower_qubit_getindices(state, node.qubits) + stmt = state.current_frame.push(qubit.MeasureQubitList(qubits)) + + # NOTE: add for classically controlled lowering + key = node.gate.key + if isinstance(key, cirq.MeasurementKey): + key = key.name + state.current_frame.defs[key] = stmt.result + + return stmt + def visit_SingleQubitPauliStringGateOperation( self, - state: lowering.State[CirqNode], + state: lowering.State[cirq.Circuit], node: cirq.SingleQubitPauliStringGateOperation, ): + if isinstance(node.pauli, cirq.IdentityGate): + # TODO: do we need an identity gate in clifford? + return + qargs = self.lower_qubit_getindices(state, (node.qubit,)) match node.pauli: case cirq.X: - op_ = op.stmts.X() + clifford_stmt = clifford.stmts.X case cirq.Y: - op_ = op.stmts.Y() + clifford_stmt = clifford.stmts.Y case cirq.Z: - op_ = op.stmts.Z() - case cirq.I: - op_ = op.stmts.Identity(sites=1) + clifford_stmt = clifford.stmts.Z case _: raise lowering.BuildError(f"Unexpected Pauli operation {node.pauli}") - state.current_frame.push(op_) - qargs = self.lower_qubit_getindices(state, [node.qubit]) - return state.current_frame.push(qubit.Apply(op_.result, qargs)) + return state.current_frame.push(clifford_stmt(qargs)) - def visit_HPowGate(self, state: lowering.State[CirqNode], node: cirq.HPowGate): - if abs(node.exponent) == 1: - return state.current_frame.push(op.stmts.H()) + def visit_HPowGate(self, state: lowering.State[cirq.Circuit], node: cirq.HPowGate): + qargs = self.lower_qubit_getindices(state, node.qubits) - # NOTE: decompose into products of paulis for arbitrary exponents according to _decompose_ method - # can't use decompose directly since that method requires qubits to be passed in for some reason - y_rhs = state.lower(cirq.YPowGate(exponent=0.25)).expect_one() - x = state.lower( - cirq.XPowGate(exponent=node.exponent, global_shift=node.global_shift) - ).expect_one() - y_lhs = state.lower(cirq.YPowGate(exponent=-0.25)).expect_one() - - # NOTE: reversed order since we're creating a mult stmt - m_lhs = state.current_frame.push(op.stmts.Mult(y_lhs, x)) - return state.current_frame.push(op.stmts.Mult(m_lhs.result, y_rhs)) + if node.gate.exponent % 2 == 1: + return state.current_frame.push(clifford.stmts.H(qargs)) - def visit_XPowGate(self, state: lowering.State[CirqNode], node: cirq.XPowGate): - if abs(node.exponent == 1): - return state.current_frame.push(op.stmts.X()) - - return self.visit(state, node.in_su2()) + # NOTE: decompose into products of paulis for arbitrary exponents according to _decompose_ method + for subnode in cirq.decompose_once(node): + self.visit(state, subnode) - def visit_YPowGate(self, state: lowering.State[CirqNode], node: cirq.YPowGate): - if abs(node.exponent == 1): - return state.current_frame.push(op.stmts.Y()) + def visit_XPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation + ): + qargs = self.lower_qubit_getindices(state, node.qubits) + if node.gate.exponent % 2 == 1: + return state.current_frame.push(clifford.stmts.X(qargs)) - return self.visit(state, node.in_su2()) + angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Rx(angle.result, qargs)) - def visit_ZPowGate(self, state: lowering.State[CirqNode], node: cirq.ZPowGate): - if node.exponent == 0.5: - return state.current_frame.push(op.stmts.S()) + def visit_YPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation + ): + qargs = self.lower_qubit_getindices(state, node.qubits) + if node.gate.exponent % 2 == 1: + return state.current_frame.push(clifford.stmts.Y(qargs)) - if node.exponent == 0.25: - return state.current_frame.push(op.stmts.T()) + angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Ry(angle.result, qargs)) - if abs(node.exponent == 1): - return state.current_frame.push(op.stmts.Z()) + def visit_ZPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation + ): + qargs = self.lower_qubit_getindices(state, node.qubits) - # NOTE: just for the Z gate, an arbitrary exponent is equivalent to the ShiftOp - # up to a minus sign! - t = -node.exponent - theta = state.current_frame.push(py.Constant(math.pi * t)) - return state.current_frame.push(op.stmts.ShiftOp(theta=theta.result)) + if abs(node.gate.exponent) == 0.5: + adjoint = node.gate.exponent < 0 + return state.current_frame.push( + clifford.stmts.S(adjoint=adjoint, qubits=qargs) + ) - def visit_Rx(self, state: lowering.State[CirqNode], node: cirq.Rx): - x = state.current_frame.push(op.stmts.X()) - angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent)) - return state.current_frame.push(op.stmts.Rot(axis=x.result, angle=angle.result)) + if abs(node.gate.exponent) == 0.25: + adjoint = node.gate.exponent < 0 + return state.current_frame.push( + clifford.stmts.T(adjoint=adjoint, qubits=qargs) + ) - def visit_Ry(self, state: lowering.State[CirqNode], node: cirq.Ry): - y = state.current_frame.push(op.stmts.Y()) - angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent)) - return state.current_frame.push(op.stmts.Rot(axis=y.result, angle=angle.result)) + if node.gate.exponent % 2 == 1: + return state.current_frame.push(clifford.stmts.Z(qubits=qargs)) - def visit_Rz(self, state: lowering.State[CirqNode], node: cirq.Rz): - z = state.current_frame.push(op.stmts.Z()) - angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent)) - return state.current_frame.push(op.stmts.Rot(axis=z.result, angle=angle.result)) + angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Rz(angle.result, qargs)) - def visit_CXPowGate(self, state: lowering.State[CirqNode], node: cirq.CXPowGate): - x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Control(x, n_controls=1)) + def visit_Rx(self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation): + qargs = self.lower_qubit_getindices(state, node.qubits) + angle = state.current_frame.push(py.Constant(value=0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Rx(angle.result, qargs)) - def visit_CZPowGate(self, state: lowering.State[CirqNode], node: cirq.CZPowGate): - z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Control(z, n_controls=1)) + def visit_Ry(self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation): + qargs = self.lower_qubit_getindices(state, node.qubits) + angle = state.current_frame.push(py.Constant(value=0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Ry(angle.result, qargs)) - def visit_ControlledOperation( - self, state: lowering.State[CirqNode], node: cirq.ControlledOperation - ): - return self.visit_GateOperation(state, node) + def visit_Rz(self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation): + qargs = self.lower_qubit_getindices(state, node.qubits) + angle = state.current_frame.push(py.Constant(value=0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Rz(angle.result, qargs)) - def visit_ControlledGate( - self, state: lowering.State[CirqNode], node: cirq.ControlledGate + def visit_CXPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation ): - op_ = state.lower(node.sub_gate).expect_one() - n_controls = node.num_controls() - return state.current_frame.push(op.stmts.Control(op_, n_controls=n_controls)) - - def visit_XXPowGate(self, state: lowering.State[CirqNode], node: cirq.XXPowGate): - x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Kron(x, x)) - - def visit_YYPowGate(self, state: lowering.State[CirqNode], node: cirq.YYPowGate): - y = state.lower(cirq.YPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Kron(y, y)) - - def visit_ZZPowGate(self, state: lowering.State[CirqNode], node: cirq.ZZPowGate): - z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Kron(z, z)) - - def visit_CCXPowGate(self, state: lowering.State[CirqNode], node: cirq.CCXPowGate): - x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Control(x, n_controls=2)) + if node.gate.exponent % 2 == 0: + return - def visit_CCZPowGate(self, state: lowering.State[CirqNode], node: cirq.CCZPowGate): - z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Control(z, n_controls=2)) + if node.gate.exponent % 2 != 1: + raise lowering.BuildError("Exponents of CX gate are not supported!") - def visit_BitFlipChannel( - self, state: lowering.State[CirqNode], node: cirq.BitFlipChannel - ): - x = state.current_frame.push(op.stmts.X()) - p = state.current_frame.push(py.Constant(node.p)) + control, target = node.qubits + control_qarg = self.lower_qubit_getindices(state, (control,)) + target_qarg = self.lower_qubit_getindices(state, (target,)) return state.current_frame.push( - noise.stmts.PauliError(basis=x.result, p=p.result) - ) - - def visit_AmplitudeDampingChannel( - self, state: lowering.State[CirqNode], node: cirq.AmplitudeDampingChannel - ): - r = state.current_frame.push(op.stmts.Reset()) - p = state.current_frame.push(py.Constant(node.gamma)) - - # TODO: do we need a dedicated noise stmt for this? Using PauliError - # with this basis feels like a hack - noise_channel = state.current_frame.push( - noise.stmts.PauliError(basis=r.result, p=p.result) + clifford.stmts.CX(controls=control_qarg, targets=target_qarg) ) - return noise_channel - - def visit_GeneralizedAmplitudeDampingChannel( - self, - state: lowering.State[CirqNode], - node: cirq.GeneralizedAmplitudeDampingChannel, + def visit_CZPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation ): - p = state.current_frame.push(py.Constant(node.p)).result - gamma = state.current_frame.push(py.Constant(node.gamma)).result - - # NOTE: cirq has a weird convention here: if p == 1, we have AmplitudeDampingChannel, - # which basically means p is the probability of the environment being in the vacuum state - prob0 = state.current_frame.push(py.binop.Mult(p, gamma)).result - one_ = state.current_frame.push(py.Constant(1)).result - p_minus_1 = state.current_frame.push(py.binop.Sub(one_, p)).result - prob1 = state.current_frame.push(py.binop.Mult(p_minus_1, gamma)).result - - r0 = state.current_frame.push(op.stmts.Reset()).result - r1 = state.current_frame.push(op.stmts.ResetToOne()).result + if node.gate.exponent % 2 == 0: + return - probs = state.current_frame.push(ilist.New(values=(prob0, prob1))).result - ops = state.current_frame.push(ilist.New(values=(r0, r1))).result + if node.gate.exponent % 2 != 1: + raise lowering.BuildError("Exponents of CZ gate are not supported!") - noise_channel = state.current_frame.push( - noise.stmts.StochasticUnitaryChannel(probabilities=probs, operators=ops) + control, target = node.qubits + control_qarg = self.lower_qubit_getindices(state, (control,)) + target_qarg = self.lower_qubit_getindices(state, (target,)) + return state.current_frame.push( + clifford.stmts.CZ(controls=control_qarg, targets=target_qarg) ) - return noise_channel - - def visit_DepolarizingChannel( - self, state: lowering.State[CirqNode], node: cirq.DepolarizingChannel - ): - p = state.current_frame.push(py.Constant(node.p)).result - return state.current_frame.push(noise.stmts.Depolarize(p)) - - def visit_AsymmetricDepolarizingChannel( - self, state: lowering.State[CirqNode], node: cirq.AsymmetricDepolarizingChannel + def visit_ControlledOperation( + self, state: lowering.State[cirq.Circuit], node: cirq.ControlledOperation ): - nqubits = node.num_qubits() - if nqubits > 2: - raise lowering.BuildError( - "AsymmetricDepolarizingChannel applied to more than 2 qubits is not supported!" - ) + match node.gate.sub_gate: + case cirq.X: + stmt = clifford.stmts.CX + case cirq.Y: + stmt = clifford.stmts.CY + case cirq.Z: + stmt = clifford.stmts.CZ + case _: + raise lowering.BuildError( + f"Cannot lowering controlled operation: {node}" + ) - if nqubits == 1: - p_x = state.current_frame.push(py.Constant(node.p_x)).result - p_y = state.current_frame.push(py.Constant(node.p_y)).result - p_z = state.current_frame.push(py.Constant(node.p_z)).result - params = state.current_frame.push(ilist.New(values=(p_x, p_y, p_z))).result - return state.current_frame.push(noise.stmts.SingleQubitPauliChannel(params)) - - # NOTE: nqubits == 2 - error_probs = node.error_probabilities - paulis = ("I", "X", "Y", "Z") - values = [] - for p1 in paulis: - for p2 in paulis: - if p1 == p2 == "I": - continue - - p = error_probs.get(p1 + p2, 0.0) - p_ssa = state.current_frame.push(py.Constant(p)).result - values.append(p_ssa) - - params = state.current_frame.push(ilist.New(values=values)).result - return state.current_frame.push(noise.stmts.TwoQubitPauliChannel(params)) + control, target = node.qubits + control_qarg = self.lower_qubit_getindices(state, (control,)) + target_qarg = self.lower_qubit_getindices(state, (target,)) + return state.current_frame.push(stmt(control_qarg, target_qarg)) + + # def visit_BitFlipChannel( + # self, state: lowering.State[cirq.Circuit], node: cirq.BitFlipChannel + # ): + # x = state.current_frame.push(op.stmts.X()) + # p = state.current_frame.push(py.Constant(node.p)) + # return state.current_frame.push( + # noise.stmts.PauliError(basis=x.result, p=p.result) + # ) + + # def visit_AmplitudeDampingChannel( + # self, state: lowering.State[cirq.Circuit], node: cirq.AmplitudeDampingChannel + # ): + # r = state.current_frame.push(op.stmts.Reset()) + # p = state.current_frame.push(py.Constant(node.gamma)) + + # # TODO: do we need a dedicated noise stmt for this? Using PauliError + # # with this basis feels like a hack + # noise_channel = state.current_frame.push( + # noise.stmts.PauliError(basis=r.result, p=p.result) + # ) + + # return noise_channel + + # def visit_GeneralizedAmplitudeDampingChannel( + # self, + # state: lowering.State[cirq.Circuit], + # node: cirq.GeneralizedAmplitudeDampingChannel, + # ): + # p = state.current_frame.push(py.Constant(node.p)).result + # gamma = state.current_frame.push(py.Constant(node.gamma)).result + + # # NOTE: cirq has a weird convention here: if p == 1, we have AmplitudeDampingChannel, + # # which basically means p is the probability of the environment being in the vacuum state + # prob0 = state.current_frame.push(py.binop.Mult(p, gamma)).result + # one_ = state.current_frame.push(py.Constant(1)).result + # p_minus_1 = state.current_frame.push(py.binop.Sub(one_, p)).result + # prob1 = state.current_frame.push(py.binop.Mult(p_minus_1, gamma)).result + + # r0 = state.current_frame.push(op.stmts.Reset()).result + # r1 = state.current_frame.push(op.stmts.ResetToOne()).result + + # probs = state.current_frame.push(ilist.New(values=(prob0, prob1))).result + # ops = state.current_frame.push(ilist.New(values=(r0, r1))).result + + # noise_channel = state.current_frame.push( + # noise.stmts.StochasticUnitaryChannel(probabilities=probs, operators=ops) + # ) + + # return noise_channel + + # def visit_DepolarizingChannel( + # self, state: lowering.State[cirq.Circuit], node: cirq.DepolarizingChannel + # ): + # p = state.current_frame.push(py.Constant(node.p)).result + # return state.current_frame.push(noise.stmts.Depolarize(p)) + + # def visit_AsymmetricDepolarizingChannel( + # self, state: lowering.State[cirq.Circuit], node: cirq.AsymmetricDepolarizingChannel + # ): + # nqubits = node.num_qubits() + # if nqubits > 2: + # raise lowering.BuildError( + # "AsymmetricDepolarizingChannel applied to more than 2 qubits is not supported!" + # ) + + # if nqubits == 1: + # p_x = state.current_frame.push(py.Constant(node.p_x)).result + # p_y = state.current_frame.push(py.Constant(node.p_y)).result + # p_z = state.current_frame.push(py.Constant(node.p_z)).result + # params = state.current_frame.push(ilist.New(values=(p_x, p_y, p_z))).result + # return state.current_frame.push(noise.stmts.SingleQubitPauliChannel(params)) + + # # NOTE: nqubits == 2 + # error_probs = node.error_probabilities + # paulis = ("I", "X", "Y", "Z") + # values = [] + # for p1 in paulis: + # for p2 in paulis: + # if p1 == p2 == "I": + # continue + + # p = error_probs.get(p1 + p2, 0.0) + # p_ssa = state.current_frame.push(py.Constant(p)).result + # values.append(p_ssa) + + # params = state.current_frame.push(ilist.New(values=values)).result + # return state.current_frame.push(noise.stmts.TwoQubitPauliChannel(params)) diff --git a/test/cirq_utils/test_cirq_to_squin.py b/test/cirq_utils/test_cirq_to_squin.py index 6da8fe0b..81a8de14 100644 --- a/test/cirq_utils/test_cirq_to_squin.py +++ b/test/cirq_utils/test_cirq_to_squin.py @@ -44,7 +44,7 @@ def controlled_gates(): return cirq.Circuit( cirq.H(q1), cirq.X(q0).controlled_by(q1), - cirq.Rx(rads=math.pi / 4).on(q0).controlled_by(q1), + cirq.Y.on(q0).controlled_by(q1), cirq.measure(q0, q1), ) @@ -90,7 +90,7 @@ def two_qubit_pow_gates(): q1 = cirq.LineQubit(1) return cirq.Circuit( - cirq.CX(q0, q1) ** 2, cirq.CZ(q0, q1) ** 0.123, cirq.measure(q0, q1) + cirq.CX(q0, q1) ** 2, cirq.CZ(q0, q1) ** -1, cirq.measure(q0, q1) ) @@ -177,8 +177,8 @@ def nested_circuit(): two_qubit_pow_gates, swap_circuit, three_qubit_gates, - noise_channels, - depolarizing_channels, + # noise_channels, + # depolarizing_channels, ], ) def test_circuit(circuit_f, run_sim: bool = False): From 06823b4e6a1f75ece48c4f249fb509108c67d8f6 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 1 Oct 2025 09:26:13 +0200 Subject: [PATCH 02/11] Remove leftover code --- src/bloqade/cirq_utils/lowering.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 670270cc..53bce1f5 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -290,7 +290,6 @@ def visit_TaggedOperation( self, state: lowering.State[cirq.Circuit], node: cirq.TaggedOperation ): return self.visit(state, node.untagged) - # state.lower(node.untagged) def visit_ClassicallyControlledOperation( self, From 4b54b39e1d7cc1b9c01b08dbd87e622fcf1dcccc Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 1 Oct 2025 11:39:45 +0200 Subject: [PATCH 03/11] Implement cirq emit for clifford dialect --- src/bloqade/cirq_utils/emit/__init__.py | 2 +- src/bloqade/cirq_utils/emit/base.py | 31 ++- src/bloqade/cirq_utils/emit/clifford.py | 90 +++++++ test/cirq_utils/test_clifford_to_cirq.py | 283 +++++++++++++++++++++++ 4 files changed, 394 insertions(+), 12 deletions(-) create mode 100644 src/bloqade/cirq_utils/emit/clifford.py create mode 100644 test/cirq_utils/test_clifford_to_cirq.py diff --git a/src/bloqade/cirq_utils/emit/__init__.py b/src/bloqade/cirq_utils/emit/__init__.py index 098d868b..98caadcb 100644 --- a/src/bloqade/cirq_utils/emit/__init__.py +++ b/src/bloqade/cirq_utils/emit/__init__.py @@ -1,3 +1,3 @@ # NOTE: just to register methods -from . import op as op, noise as noise, qubit as qubit +from . import op as op, noise as noise, qubit as qubit, clifford as clifford from .base import emit_circuit as emit_circuit diff --git a/src/bloqade/cirq_utils/emit/base.py b/src/bloqade/cirq_utils/emit/base.py index 7adfc829..4d664782 100644 --- a/src/bloqade/cirq_utils/emit/base.py +++ b/src/bloqade/cirq_utils/emit/base.py @@ -137,7 +137,7 @@ class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]): dialects: ir.DialectGroup = field(default_factory=_default_kernel) void = cirq.Circuit() qubits: Sequence[cirq.Qid] | None = None - _cached_circuit_operations: dict[int, cirq.CircuitOperation] = field( + _cached_invokes: dict[int, cirq.FrozenCircuit] = field( init=False, default_factory=dict ) @@ -193,12 +193,22 @@ def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function): @impl(func.Invoke) def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke): - stmt_hash = hash((stmt.callee, stmt.inputs)) - if ( - cached_circuit_op := emit._cached_circuit_operations.get(stmt_hash) - ) is not None: + try: + stmt_hash = hash( + (stmt.callee, tuple(frame.get(input) for input in stmt.inputs)) + ) + except (TypeError, interp.InterpreterError): + # NOTE: avoid unhashable types and missing keys, just don't cache them + stmt_hash = None + + if stmt_hash is not None: + cached_circuit = emit._cached_invokes.get(stmt_hash) + else: + cached_circuit = None + + if cached_circuit is not None: # NOTE: cache hit - frame.circuit.append(cached_circuit_op) + frame.circuit.append(cached_circuit.all_operations()) return () ret = stmt.result @@ -230,9 +240,8 @@ def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke): if return_stmt is not None: frame.entries[ret] = sub_frame.get(return_stmt.value) - circuit_op = cirq.CircuitOperation( - sub_circuit.freeze(), use_repetition_ids=False - ) - emit._cached_circuit_operations[stmt_hash] = circuit_op - frame.circuit.append(circuit_op) + if stmt_hash is not None: + emit._cached_invokes[stmt_hash] = sub_circuit.freeze() + + frame.circuit.append(sub_circuit.all_operations()) return () diff --git a/src/bloqade/cirq_utils/emit/clifford.py b/src/bloqade/cirq_utils/emit/clifford.py new file mode 100644 index 00000000..6eb16905 --- /dev/null +++ b/src/bloqade/cirq_utils/emit/clifford.py @@ -0,0 +1,90 @@ +import math + +import cirq +from kirin.interp import MethodTable, impl + +from bloqade.squin import clifford + +from .base import EmitCirq, EmitCirqFrame + + +@clifford.dialect.register(key="emit.cirq") +class EmitCirqOpMethods(MethodTable): + + @impl(clifford.stmts.X) + @impl(clifford.stmts.Y) + @impl(clifford.stmts.Z) + @impl(clifford.stmts.H) + def hermitian( + self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.SingleQubitGate + ): + qubits = frame.get(stmt.qubits) + cirq_op = getattr(cirq, stmt.name.upper()) + frame.circuit.append(cirq_op.on_each(qubits)) + return (frame.circuit,) + + @impl(clifford.stmts.S) + @impl(clifford.stmts.T) + def unitary( + self, + emit: EmitCirq, + frame: EmitCirqFrame, + stmt: clifford.stmts.SingleQubitNonHermitianGate, + ): + qubits = frame.get(stmt.qubits) + cirq_op = getattr(cirq, stmt.name.upper()) + if stmt.adjoint: + cirq_op = cirq_op ** (-1) + + frame.circuit.append(cirq_op.on_each(qubits)) + return (frame.circuit,) + + @impl(clifford.stmts.SqrtX) + @impl(clifford.stmts.SqrtY) + def sqrt( + self, + emit: EmitCirq, + frame: EmitCirqFrame, + stmt: clifford.stmts.SqrtX | clifford.stmts.SqrtY, + ): + qubits = frame.get(stmt.qubits) + + exponent = 0.5 + if stmt.adjoint: + exponent *= -1 + + if isinstance(stmt, clifford.stmts.SqrtX): + cirq_op = cirq.XPowGate(exponent=exponent) + else: + cirq_op = cirq.YPowGate(exponent=exponent) + + frame.circuit.append(cirq_op.on_each(qubits)) + return (frame.circuit,) + + @impl(clifford.stmts.CX) + @impl(clifford.stmts.CY) + @impl(clifford.stmts.CZ) + def control( + self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.ControlledGate + ): + controls = frame.get(stmt.controls) + targets = frame.get(stmt.targets) + cirq_op = getattr(cirq, stmt.name.upper()) + cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)] + frame.circuit.append(cirq_op.on_each(cirq_qubits)) + return (frame.circuit,) + + @impl(clifford.stmts.Rx) + @impl(clifford.stmts.Ry) + @impl(clifford.stmts.Rz) + def rot( + self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.RotationGate + ): + qubits = frame.get(stmt.qubits) + + turns = frame.get(stmt.angle) + angle = turns * 2 * math.pi + cirq_op = getattr(cirq, stmt.name.title())(rads=angle) + + frame.circuit.append(cirq_op.on_each(qubits)) + return (frame.circuit,) diff --git a/test/cirq_utils/test_clifford_to_cirq.py b/test/cirq_utils/test_clifford_to_cirq.py new file mode 100644 index 00000000..46c3c259 --- /dev/null +++ b/test/cirq_utils/test_clifford_to_cirq.py @@ -0,0 +1,283 @@ +import math +import typing + +import cirq +import pytest +from kirin.emit import EmitError +from kirin.dialects import ilist + +from bloqade import squin +from bloqade.cirq_utils import emit, emit_circuit + + +def test_pauli(): + @squin.kernel + def main(): + q = squin.qubit.new(2) + q2 = squin.qubit.new(4) + x = squin.op.x() + y = squin.op.y() + z = squin.op.z() + squin.qubit.apply(x, q[0]) + squin.qubit.apply(y, q2[0]) + squin.qubit.apply(z, q2[3]) + + circuit = emit_circuit(main) + + print(circuit) + + qbits = circuit.all_qubits() + assert len(qbits) == 3 + assert isinstance(qbit := list(qbits)[-1], cirq.LineQubit) + assert qbit.x == 5 + + +@pytest.mark.parametrize("op_name", ["h", "s", "t", "x", "y", "z"]) +def test_basic_op(op_name: str): + @squin.kernel + def main(): + q = squin.qubit.new(1) + getattr(squin, op_name)(q) + + emit_circuit(main) + + +def test_control(): + @squin.kernel + def main(): + q = squin.qubit.new(2) + squin.h(q[0]) + squin.cx(q[0], q[1]) + + circuit = emit_circuit(main) + + print(circuit) + + assert len(circuit) == 2 + assert circuit[1].operations[0].gate == cirq.CNOT + + +def test_custom_qubits(): + @squin.kernel + def main(): + q = squin.qubit.new(2) + squin.h(q[0]) + squin.cx(q[0], q[1]) + + qubits = [cirq.GridQubit(0, 1), cirq.GridQubit(2, 2)] + circuit = emit_circuit(main, qubits=qubits) + + print(circuit) + + circuit_qubits = circuit.all_qubits() + assert len(circuit_qubits) == 2 + assert frozenset(qubits) == circuit_qubits + + +def test_composed_kernels(): + @squin.kernel + def sub_kernel(q_: ilist.IList[squin.qubit.Qubit, typing.Any]): + squin.h(q_[0]) + + @squin.kernel + def main(): + q = squin.qubit.new(2) + sub_kernel(q) + + circuit = emit_circuit(main) + + print(circuit) + + assert len(circuit) == 1 + assert len(circuit[0].operations) == 1 + assert isinstance(circuit[0].operations[0], cirq.GateOperation) + + +def test_nested_kernels(): + @squin.kernel + def sub_kernel2(q2_: ilist.IList[squin.qubit.Qubit, typing.Any]): + squin.cx(q2_[0], q2_[1]) + + @squin.kernel + def sub_kernel(q_: ilist.IList[squin.qubit.Qubit, typing.Any]): + squin.h(q_[0]) + sub_kernel2(q_) + + @squin.kernel + def main(): + q = squin.qubit.new(2) + sub_kernel(q) + + circuit = emit_circuit(main) + + print(circuit) + + +def test_return_value(): + @squin.kernel + def sub_kernel(): + q = squin.qubit.new(2) + squin.h(q[0]) + squin.cx(q[0], q[1]) + return q + + @squin.kernel + def main(): + q = sub_kernel() + squin.h(q[0]) + + circuit = emit_circuit(main) + + print(circuit) + + with pytest.raises(EmitError): + emit_circuit(sub_kernel) + + @squin.kernel + def main2(): + q = sub_kernel() + squin.h(q[0]) + return q + + circuit2 = emit_circuit(main2, ignore_returns=True) + print(circuit2) + + assert circuit2 == circuit + + +def test_return_qubits(): + @squin.kernel + def sub_kernel(q: ilist.IList[squin.qubit.Qubit, typing.Any]): + squin.h(q[0]) + q2 = squin.qubit.new(3) + squin.cx(q[0], q2[2]) + return q2 + + @squin.kernel + def main(): + q = squin.qubit.new(2) + q2_ = sub_kernel(q) + squin.x(q2_[0]) + + circuit = emit_circuit(main) + + print(circuit) + + +def test_measurement(): + @squin.kernel + def main(): + q = squin.qubit.new(2) + squin.broadcast.y(q) + squin.qubit.measure(q) + + circuit = emit_circuit(main) + + print(circuit) + + +def test_adjoint(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + squin.s(q[0]) + squin.s_adj(q[0]) + + circuit = emit_circuit(main) + print(circuit) + + +def test_u3(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + squin.u3(0.323, 1.123, math.pi / 7, q[0]) + + circuit = emit_circuit(main) + print(circuit) + + +def test_shift(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + squin.shift(math.pi / 7, q[0]) + + circuit = emit_circuit(main) + print(circuit) + + +def test_invoke_cache(): + @squin.kernel + def sub_kernel(q_: squin.qubit.Qubit): + squin.h(q_) + + @squin.kernel + def main(): + q = squin.qubit.new(2) + q0 = q[0] + sub_kernel(q0) + sub_kernel(q[1]) + sub_kernel(q0) + + target = emit.base.EmitCirq(main.dialects) + + circuit = target.run(main, ()) + + print(circuit) + + # caches as well as squin.h and squin.broadcast.h with the different qubits + assert len(target._cached_invokes) == 6 + + +def test_rot(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + squin.rx(math.pi / 2, q[0]) + + circuit = emit_circuit(main) + + print(circuit) + + assert circuit[0].operations[0].gate == cirq.Rx(rads=math.pi / 2) + + +def test_additional_stmts(): + @squin.kernel + def main(): + q = squin.qubit.new(3) + squin.rot(math.pi / 4, math.pi / 2, -math.pi / 4, q[0]) + squin.sqrt_x(q[1]) + squin.sqrt_y(q[2]) + + main.print() + + circuit = emit_circuit(main) + + print(circuit) + + q = cirq.LineQubit.range(3) + expected_circuit = cirq.Circuit( + cirq.Rz(rads=math.pi / 4).on(q[0]), + cirq.Ry(rads=math.pi / 2).on(q[0]), + cirq.Rz(rads=-math.pi / 4).on(q[0]), + cirq.X(q[1]) ** 0.5, + cirq.Y(q[2]) ** 0.5, + ) + + assert circuit == expected_circuit + + +def test_return_measurement(): + + @squin.kernel + def coinflip(): + qubit = squin.qubit.new(1)[0] + squin.h(qubit) + return squin.qubit.measure(qubit) + + coinflip.print() + + circuit = emit_circuit(coinflip, ignore_returns=True) + print(circuit) From 650cfb80c0bbce0920697434c003dee85a7966cb Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 1 Oct 2025 13:58:28 +0200 Subject: [PATCH 04/11] Make qubits hashable --- src/bloqade/cirq_utils/emit/base.py | 2 +- src/bloqade/cirq_utils/emit/qubit.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/bloqade/cirq_utils/emit/base.py b/src/bloqade/cirq_utils/emit/base.py index 4d664782..326ebd55 100644 --- a/src/bloqade/cirq_utils/emit/base.py +++ b/src/bloqade/cirq_utils/emit/base.py @@ -115,7 +115,7 @@ def main(): f"The method from which you're trying to emit a circuit takes {len(mt.args)} as input, but you passed in {len(args)} via the `args` keyword!" ) - emitter = EmitCirq(qubits=qubits) + emitter = EmitCirq(qubits=circuit_qubits) return emitter.run(mt, args=args) diff --git a/src/bloqade/cirq_utils/emit/qubit.py b/src/bloqade/cirq_utils/emit/qubit.py index 080921e7..47d736ab 100644 --- a/src/bloqade/cirq_utils/emit/qubit.py +++ b/src/bloqade/cirq_utils/emit/qubit.py @@ -14,11 +14,13 @@ def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New): n_qubits = frame.get(stmt.n_qubits) if frame.qubits is not None: - cirq_qubits = [frame.qubits[i + frame.qubit_index] for i in range(n_qubits)] + cirq_qubits = tuple( + frame.qubits[i + frame.qubit_index] for i in range(n_qubits) + ) else: - cirq_qubits = [ + cirq_qubits = tuple( cirq.LineQubit(i + frame.qubit_index) for i in range(n_qubits) - ] + ) frame.qubit_index += n_qubits return (cirq_qubits,) From 6629ce43a5e7c21873a8ac86d55a75e6e66256e3 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 1 Oct 2025 14:03:31 +0200 Subject: [PATCH 05/11] Don't return the circuit --- src/bloqade/cirq_utils/emit/clifford.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/bloqade/cirq_utils/emit/clifford.py b/src/bloqade/cirq_utils/emit/clifford.py index 6eb16905..eb97a632 100644 --- a/src/bloqade/cirq_utils/emit/clifford.py +++ b/src/bloqade/cirq_utils/emit/clifford.py @@ -21,7 +21,7 @@ def hermitian( qubits = frame.get(stmt.qubits) cirq_op = getattr(cirq, stmt.name.upper()) frame.circuit.append(cirq_op.on_each(qubits)) - return (frame.circuit,) + return () @impl(clifford.stmts.S) @impl(clifford.stmts.T) @@ -37,7 +37,7 @@ def unitary( cirq_op = cirq_op ** (-1) frame.circuit.append(cirq_op.on_each(qubits)) - return (frame.circuit,) + return () @impl(clifford.stmts.SqrtX) @impl(clifford.stmts.SqrtY) @@ -59,7 +59,7 @@ def sqrt( cirq_op = cirq.YPowGate(exponent=exponent) frame.circuit.append(cirq_op.on_each(qubits)) - return (frame.circuit,) + return () @impl(clifford.stmts.CX) @impl(clifford.stmts.CY) @@ -72,7 +72,7 @@ def control( cirq_op = getattr(cirq, stmt.name.upper()) cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)] frame.circuit.append(cirq_op.on_each(cirq_qubits)) - return (frame.circuit,) + return () @impl(clifford.stmts.Rx) @impl(clifford.stmts.Ry) @@ -87,4 +87,4 @@ def rot( cirq_op = getattr(cirq, stmt.name.title())(rads=angle) frame.circuit.append(cirq_op.on_each(qubits)) - return (frame.circuit,) + return () From 77cd5a87713288ed013426e6cee662a857ce17c7 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 1 Oct 2025 14:38:35 +0200 Subject: [PATCH 06/11] Basic lowering for circuit operations --- src/bloqade/cirq_utils/lowering.py | 35 +++++++++++++++++++++++++-- test/cirq_utils/test_cirq_to_squin.py | 9 ++----- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 53bce1f5..e41627cc 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -148,7 +148,14 @@ def main(): ) -CirqNode = cirq.Circuit | cirq.Moment | cirq.Gate | cirq.Qid | cirq.Operation +CirqNode = ( + cirq.Circuit + | cirq.FrozenCircuit + | cirq.Moment + | cirq.Gate + | cirq.Qid + | cirq.Operation +) DecomposeNode = ( cirq.SwapPowGate @@ -261,7 +268,9 @@ def lower_global( raise lowering.BuildError("Literals not supported in cirq circuit") def visit_Circuit( - self, state: lowering.State[cirq.Circuit], node: cirq.Circuit + self, + state: lowering.State[cirq.Circuit], + node: cirq.Circuit | cirq.FrozenCircuit, ) -> lowering.Result: for moment in node: self.visit_Moment(state, moment) @@ -516,6 +525,28 @@ def visit_ControlledOperation( target_qarg = self.lower_qubit_getindices(state, (target,)) return state.current_frame.push(stmt(control_qarg, target_qarg)) + def visit_FrozenCircuit( + self, state: lowering.State[cirq.Circuit], node: cirq.FrozenCircuit + ): + return self.visit_Circuit(state, node) + + def visit_CircuitOperation( + self, state: lowering.State[cirq.Circuit], node: cirq.CircuitOperation + ): + reps = node.repetitions + + if not isinstance(reps, int): + raise lowering.BuildError( + f"Cannot lower CircuitOperation with non-integer repetitions: {node}" + ) + + if reps > 1: + raise lowering.BuildError( + "Repetitions of circuit operatiosn not yet supported" + ) + + return self.visit(state, node.circuit) + # def visit_BitFlipChannel( # self, state: lowering.State[cirq.Circuit], node: cirq.BitFlipChannel # ): diff --git a/test/cirq_utils/test_cirq_to_squin.py b/test/cirq_utils/test_cirq_to_squin.py index 81a8de14..c7f1d3f2 100644 --- a/test/cirq_utils/test_cirq_to_squin.py +++ b/test/cirq_utils/test_cirq_to_squin.py @@ -160,7 +160,7 @@ def nested_circuit(): cirq.CircuitOperation( cirq.Circuit(cirq.H(q[1]), cirq.CX(q[1], q[2])).freeze(), use_repetition_ids=False, - ).controlled_by(q[0]), + ), cirq.measure(*q), ) @@ -177,6 +177,7 @@ def nested_circuit(): two_qubit_pow_gates, swap_circuit, three_qubit_gates, + nested_circuit, # noise_channels, # depolarizing_channels, ], @@ -211,12 +212,6 @@ def test_return_register(): assert kernel.return_type.body.is_subseteq(ilist.IListType) -@pytest.mark.xfail -def test_nested_circuit(): - # TODO: lowering for CircuitOperation - test_circuit(nested_circuit) - - def test_passing_in_register(): circuit = pow_gate_circuit() print(circuit) From 059c0c8464eb00642b47d38a777a34aa140cdc53 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Thu, 2 Oct 2025 13:16:30 +0200 Subject: [PATCH 07/11] Implement lowering for cirq noise --- src/bloqade/cirq_utils/lowering.py | 173 +++++++++--------- ...cirq_to_squin.py => test_cirq_to_squin.py} | 20 +- 2 files changed, 101 insertions(+), 92 deletions(-) rename test/cirq_utils/{cirq_to_squin.py => test_cirq_to_squin.py} (96%) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index e41627cc..276f334c 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -6,7 +6,7 @@ from kirin.rewrite import Walk, CFGCompactify from kirin.dialects import py, scf, func, ilist -from bloqade.squin import qubit, kernel, clifford +from bloqade.squin import noise, qubit, kernel, clifford def load_circuit( @@ -180,6 +180,24 @@ class Squin(lowering.LoweringABC[cirq.Circuit]): qreg_index: dict[cirq.Qid, int] = field(init=False, default_factory=dict) next_qreg_index: int = field(init=False, default=0) + two_qubit_paulis = ( + "IX", + "IY", + "IZ", + "XI", + "XX", + "XY", + "XZ", + "YI", + "YX", + "YY", + "YZ", + "ZI", + "ZX", + "ZY", + "ZZ", + ) + def __post_init__(self): # TODO: sort by cirq ordering qbits = sorted(self.circuit.all_qubits()) @@ -547,90 +565,69 @@ def visit_CircuitOperation( return self.visit(state, node.circuit) - # def visit_BitFlipChannel( - # self, state: lowering.State[cirq.Circuit], node: cirq.BitFlipChannel - # ): - # x = state.current_frame.push(op.stmts.X()) - # p = state.current_frame.push(py.Constant(node.p)) - # return state.current_frame.push( - # noise.stmts.PauliError(basis=x.result, p=p.result) - # ) - - # def visit_AmplitudeDampingChannel( - # self, state: lowering.State[cirq.Circuit], node: cirq.AmplitudeDampingChannel - # ): - # r = state.current_frame.push(op.stmts.Reset()) - # p = state.current_frame.push(py.Constant(node.gamma)) - - # # TODO: do we need a dedicated noise stmt for this? Using PauliError - # # with this basis feels like a hack - # noise_channel = state.current_frame.push( - # noise.stmts.PauliError(basis=r.result, p=p.result) - # ) - - # return noise_channel - - # def visit_GeneralizedAmplitudeDampingChannel( - # self, - # state: lowering.State[cirq.Circuit], - # node: cirq.GeneralizedAmplitudeDampingChannel, - # ): - # p = state.current_frame.push(py.Constant(node.p)).result - # gamma = state.current_frame.push(py.Constant(node.gamma)).result - - # # NOTE: cirq has a weird convention here: if p == 1, we have AmplitudeDampingChannel, - # # which basically means p is the probability of the environment being in the vacuum state - # prob0 = state.current_frame.push(py.binop.Mult(p, gamma)).result - # one_ = state.current_frame.push(py.Constant(1)).result - # p_minus_1 = state.current_frame.push(py.binop.Sub(one_, p)).result - # prob1 = state.current_frame.push(py.binop.Mult(p_minus_1, gamma)).result - - # r0 = state.current_frame.push(op.stmts.Reset()).result - # r1 = state.current_frame.push(op.stmts.ResetToOne()).result - - # probs = state.current_frame.push(ilist.New(values=(prob0, prob1))).result - # ops = state.current_frame.push(ilist.New(values=(r0, r1))).result - - # noise_channel = state.current_frame.push( - # noise.stmts.StochasticUnitaryChannel(probabilities=probs, operators=ops) - # ) - - # return noise_channel - - # def visit_DepolarizingChannel( - # self, state: lowering.State[cirq.Circuit], node: cirq.DepolarizingChannel - # ): - # p = state.current_frame.push(py.Constant(node.p)).result - # return state.current_frame.push(noise.stmts.Depolarize(p)) - - # def visit_AsymmetricDepolarizingChannel( - # self, state: lowering.State[cirq.Circuit], node: cirq.AsymmetricDepolarizingChannel - # ): - # nqubits = node.num_qubits() - # if nqubits > 2: - # raise lowering.BuildError( - # "AsymmetricDepolarizingChannel applied to more than 2 qubits is not supported!" - # ) - - # if nqubits == 1: - # p_x = state.current_frame.push(py.Constant(node.p_x)).result - # p_y = state.current_frame.push(py.Constant(node.p_y)).result - # p_z = state.current_frame.push(py.Constant(node.p_z)).result - # params = state.current_frame.push(ilist.New(values=(p_x, p_y, p_z))).result - # return state.current_frame.push(noise.stmts.SingleQubitPauliChannel(params)) - - # # NOTE: nqubits == 2 - # error_probs = node.error_probabilities - # paulis = ("I", "X", "Y", "Z") - # values = [] - # for p1 in paulis: - # for p2 in paulis: - # if p1 == p2 == "I": - # continue - - # p = error_probs.get(p1 + p2, 0.0) - # p_ssa = state.current_frame.push(py.Constant(p)).result - # values.append(p_ssa) - - # params = state.current_frame.push(ilist.New(values=values)).result - # return state.current_frame.push(noise.stmts.TwoQubitPauliChannel(params)) + def visit_BitFlipChannel( + self, state: lowering.State[cirq.Circuit], node: cirq.BitFlipChannel + ): + p = node.gate.p + p_x = state.current_frame.push(py.Constant(p)).result + p_y = p_z = state.current_frame.push(py.Constant(0)).result + qubits = self.lower_qubit_getindices(state, node.qubits) + return state.current_frame.push( + noise.stmts.SingleQubitPauliChannel(px=p_x, py=p_y, pz=p_z, qubits=qubits) + ) + + def visit_DepolarizingChannel( + self, state: lowering.State[cirq.Circuit], node: cirq.DepolarizingChannel + ): + p = state.current_frame.push(py.Constant(node.gate.p)).result + qubits = self.lower_qubit_getindices(state, node.qubits) + return state.current_frame.push(noise.stmts.Depolarize(p, qubits=qubits)) + + def visit_AsymmetricDepolarizingChannel( + self, + state: lowering.State[cirq.Circuit], + node: cirq.AsymmetricDepolarizingChannel, + ): + nqubits = node.gate.num_qubits() + if nqubits > 2: + raise lowering.BuildError( + "AsymmetricDepolarizingChannel applied to more than 2 qubits is not supported!" + ) + + if nqubits == 1: + qubits = self.lower_qubit_getindices(state, node.qubits) + p_x = state.current_frame.push(py.Constant(node.gate.p_x)).result + p_y = state.current_frame.push(py.Constant(node.gate.p_y)).result + p_z = state.current_frame.push(py.Constant(node.gate.p_z)).result + return state.current_frame.push( + noise.stmts.SingleQubitPauliChannel(p_x, p_y, p_z, qubits) + ) + + # NOTE: nqubits == 2 + error_probs = node.gate.error_probabilities + probability_values = [] + p0 = None + for key in self.two_qubit_paulis: + p = error_probs.get(key) + + if p is None: + if p0 is None: + p0 = state.current_frame.push(py.Constant(0)).result + p_ssa = p0 + else: + p_ssa = state.current_frame.push(py.Constant(p)).result + probability_values.append(p_ssa) + + probabilities = state.current_frame.push( + ilist.New(values=probability_values) + ).result + + control, target = node.qubits + control_qarg = self.lower_qubit_getindices(state, (control,)) + target_qarg = self.lower_qubit_getindices(state, (target,)) + + return state.current_frame.push( + noise.stmts.TwoQubitPauliChannel( + probabilities, controls=control_qarg, targets=target_qarg + ) + ) diff --git a/test/cirq_utils/cirq_to_squin.py b/test/cirq_utils/test_cirq_to_squin.py similarity index 96% rename from test/cirq_utils/cirq_to_squin.py rename to test/cirq_utils/test_cirq_to_squin.py index c7f1d3f2..6e0fc626 100644 --- a/test/cirq_utils/cirq_to_squin.py +++ b/test/cirq_utils/test_cirq_to_squin.py @@ -129,15 +129,22 @@ def three_qubit_gates(): ) -def noise_channels(): +def bit_flip(): q = cirq.LineQubit(0) return cirq.Circuit( cirq.X(q), cirq.bit_flip(0.1).on(q), + ) + + +def amplitude_damping(): + q = cirq.LineQubit(0) + + # NOTE: currently not supported -- marked as xfail below + return cirq.Circuit( cirq.amplitude_damp(0.1).on(q), cirq.generalized_amplitude_damp(p=0.1, gamma=0.05).on(q), - cirq.measure(q), ) @@ -178,8 +185,8 @@ def nested_circuit(): swap_circuit, three_qubit_gates, nested_circuit, - # noise_channels, - # depolarizing_channels, + bit_flip, + depolarizing_channels, ], ) def test_circuit(circuit_f, run_sim: bool = False): @@ -402,3 +409,8 @@ def multi_arg(n: int, p: float): circuit = emit_circuit(multi_arg, args=(3, 0.1)) print(circuit) + + +@pytest.mark.xfail +def test_amplitude_damping(): + test_circuit(amplitude_damping) From 6806ccdbf02d53b3aeaf8dec35a36bb6ed954c4b Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Thu, 2 Oct 2025 13:54:10 +0200 Subject: [PATCH 08/11] Implement noise emit to cirq --- src/bloqade/cirq_utils/emit/noise.py | 90 ++-- ...to_cirq.py => test_squin_noise_to_cirq.py} | 0 test/cirq_utils/test_squin_to_cirq.py | 470 ------------------ 3 files changed, 49 insertions(+), 511 deletions(-) rename test/cirq_utils/{squin_noise_to_cirq.py => test_squin_noise_to_cirq.py} (100%) delete mode 100644 test/cirq_utils/test_squin_to_cirq.py diff --git a/src/bloqade/cirq_utils/emit/noise.py b/src/bloqade/cirq_utils/emit/noise.py index bd8f8369..a5721854 100644 --- a/src/bloqade/cirq_utils/emit/noise.py +++ b/src/bloqade/cirq_utils/emit/noise.py @@ -1,36 +1,53 @@ import cirq -from kirin.emit import EmitError from kirin.interp import MethodTable, impl from bloqade.squin import noise from .base import EmitCirq, EmitCirqFrame -from .runtime import ( - KronRuntime, - BasicOpRuntime, - OperatorRuntimeABC, - PauliStringRuntime, -) @noise.dialect.register(key="emit.cirq") class EmitCirqNoiseMethods(MethodTable): + two_qubit_paulis = ( + "IX", + "IY", + "IZ", + "XI", + "XX", + "XY", + "XZ", + "YI", + "YX", + "YY", + "YZ", + "ZI", + "ZX", + "ZY", + "ZZ", + ) + @impl(noise.stmts.Depolarize) def depolarize( self, interp: EmitCirq, frame: EmitCirqFrame, stmt: noise.stmts.Depolarize ): p = frame.get(stmt.p) - gate = cirq.depolarize(p, n_qubits=1) - return (BasicOpRuntime(gate=gate),) + qubits = frame.get(stmt.qubits) + cirfq_op = cirq.depolarize(p, n_qubits=1).on_each(qubits) + frame.circuit.append(cirfq_op) + return () @impl(noise.stmts.Depolarize2) def depolarize2( self, interp: EmitCirq, frame: EmitCirqFrame, stmt: noise.stmts.Depolarize2 ): p = frame.get(stmt.p) - gate = cirq.depolarize(p, n_qubits=2) - return (BasicOpRuntime(gate=gate),) + controls = frame.get(stmt.controls) + targets = frame.get(stmt.targets) + cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)] + cirq_op = cirq.depolarize(p, n_qubits=2).on_each(cirq_qubits) + frame.circuit.append(cirq_op) + return () @impl(noise.stmts.SingleQubitPauliChannel) def single_qubit_pauli_channel( @@ -39,9 +56,15 @@ def single_qubit_pauli_channel( frame: EmitCirqFrame, stmt: noise.stmts.SingleQubitPauliChannel, ): - ps = frame.get(stmt.params) - gate = cirq.asymmetric_depolarize(*ps) - return (BasicOpRuntime(gate=gate),) + px = frame.get(stmt.px) + py = frame.get(stmt.py) + pz = frame.get(stmt.pz) + qubits = frame.get(stmt.qubits) + + cirq_op = cirq.asymmetric_depolarize(px, py, pz).on_each(qubits) + frame.circuit.append(cirq_op) + + return () @impl(noise.stmts.TwoQubitPauliChannel) def two_qubit_pauli_channel( @@ -50,33 +73,18 @@ def two_qubit_pauli_channel( frame: EmitCirqFrame, stmt: noise.stmts.TwoQubitPauliChannel, ): - ps = frame.get(stmt.params) - paulis = ("I", "X", "Y", "Z") - pauli_combinations = [ - pauli1 + pauli2 - for pauli1 in paulis - for pauli2 in paulis - if not (pauli1 == pauli2 == "I") - ] - error_probabilities = {key: p for (key, p) in zip(pauli_combinations, ps)} - gate = cirq.asymmetric_depolarize(error_probabilities=error_probabilities) - return (BasicOpRuntime(gate),) - - @staticmethod - def _op_to_key(operator: OperatorRuntimeABC) -> str: - match operator: - case KronRuntime(): - key_lhs = EmitCirqNoiseMethods._op_to_key(operator.lhs) - key_rhs = EmitCirqNoiseMethods._op_to_key(operator.rhs) - return key_lhs + key_rhs + ps = frame.get(stmt.probabilities) + error_probabilities = { + key: p for (key, p) in zip(self.two_qubit_paulis, ps) if p != 0 + } - case BasicOpRuntime(): - return str(operator.gate) + controls = frame.get(stmt.controls) + targets = frame.get(stmt.targets) + cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)] - case PauliStringRuntime(): - return operator.string + cirq_op = cirq.asymmetric_depolarize( + error_probabilities=error_probabilities + ).on_each(cirq_qubits) + frame.circuit.append(cirq_op) - case _: - raise EmitError( - f"Unexpected operator runtime in StochasticUnitaryChannel of type {type(operator).__name__} encountered!" - ) + return () diff --git a/test/cirq_utils/squin_noise_to_cirq.py b/test/cirq_utils/test_squin_noise_to_cirq.py similarity index 100% rename from test/cirq_utils/squin_noise_to_cirq.py rename to test/cirq_utils/test_squin_noise_to_cirq.py diff --git a/test/cirq_utils/test_squin_to_cirq.py b/test/cirq_utils/test_squin_to_cirq.py deleted file mode 100644 index f3cf2fd3..00000000 --- a/test/cirq_utils/test_squin_to_cirq.py +++ /dev/null @@ -1,470 +0,0 @@ -import math -import typing - -import cirq -import pytest -from kirin.emit import EmitError -from kirin.passes import inline -from kirin.dialects import ilist - -from bloqade import squin -from bloqade.cirq_utils import emit, emit_circuit - - -def test_pauli(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - q2 = squin.qubit.new(4) - x = squin.op.x() - y = squin.op.y() - z = squin.op.z() - squin.qubit.apply(x, q[0]) - squin.qubit.apply(y, q2[0]) - squin.qubit.apply(z, q2[3]) - - circuit = emit_circuit(main) - - print(circuit) - - qbits = circuit.all_qubits() - assert len(qbits) == 3 - assert isinstance(qbit := list(qbits)[-1], cirq.LineQubit) - assert qbit.x == 5 - - -@pytest.mark.parametrize("op_name", ["h", "s", "t", "x", "y", "z"]) -def test_basic_op(op_name: str): - @squin.kernel - def main(): - q = squin.qubit.new(1) - op_ = getattr(squin.op, op_name)() - squin.qubit.apply(op_, q) - - emit_circuit(main) - - -def test_control(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - squin.qubit.apply(cx, q) - - circuit = emit_circuit(main) - - print(circuit) - - assert len(circuit) == 2 - assert circuit[1].operations[0].gate == cirq.CNOT - - -def test_custom_qubits(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - squin.qubit.apply(cx, q) - - qubits = [cirq.GridQubit(0, 1), cirq.GridQubit(2, 2)] - circuit = emit_circuit(main, qubits=qubits) - - print(circuit) - - circuit_qubits = circuit.all_qubits() - assert len(circuit_qubits) == 2 - assert frozenset(qubits) == circuit_qubits - - -def test_composed_kernels(): - @squin.kernel - def sub_kernel(q_: ilist.IList[squin.qubit.Qubit, typing.Any]): - h = squin.op.h() - squin.qubit.apply(h, q_[0]) - - @squin.kernel - def main(): - q = squin.qubit.new(2) - sub_kernel(q) - - circuit = emit_circuit(main) - - print(circuit) - - assert len(circuit) == 1 - assert len(circuit[0].operations) == 1 - assert isinstance(circuit[0].operations[0], cirq.CircuitOperation) - - -def test_nested_kernels(): - @squin.kernel - def sub_kernel2(q2_: ilist.IList[squin.qubit.Qubit, typing.Any]): - cx = squin.op.control(squin.op.x(), n_controls=1) - squin.qubit.apply(cx, q2_[0], q2_[1]) - - @squin.kernel - def sub_kernel(q_: ilist.IList[squin.qubit.Qubit, typing.Any]): - h = squin.op.h() - squin.qubit.apply(h, q_[0]) - id = squin.op.identity(sites=1) - squin.qubit.apply(id, q_[1]) - sub_kernel2(q_) - - @squin.kernel - def main(): - q = squin.qubit.new(2) - sub_kernel(q) - - circuit = emit_circuit(main) - - print(circuit) - - -def test_return_value(): - @squin.kernel - def sub_kernel(q: ilist.IList[squin.qubit.Qubit, typing.Any]): - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - squin.qubit.apply(cx, q[0], q[1]) - return h - - @squin.kernel - def main(): - q = squin.qubit.new(2) - h_ = sub_kernel(q) - squin.qubit.apply(h_, q[1]) - - circuit = emit_circuit(main) - - print(circuit) - - with pytest.raises(EmitError): - emit_circuit(sub_kernel) - - @squin.kernel - def main2(): - q = squin.qubit.new(2) - h_ = sub_kernel(q) - squin.qubit.apply(h_, q[1]) - return h_ - - circuit2 = emit_circuit(main2, ignore_returns=True) - print(circuit2) - - assert circuit2 == circuit - - -def test_return_qubits(): - @squin.kernel - def sub_kernel(q: ilist.IList[squin.qubit.Qubit, typing.Any]): - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - q2 = squin.qubit.new(3) - squin.qubit.apply(cx, [q[0], q2[2]]) - return q2 - - @squin.kernel - def main(): - q = squin.qubit.new(2) - q2_ = sub_kernel(q) - squin.qubit.apply(squin.op.x(), q2_[0]) - - circuit = emit_circuit(main) - - print(circuit) - - -def test_measurement(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - y = squin.op.y() - squin.qubit.broadcast(y, q) - squin.qubit.measure(q) - - circuit = emit_circuit(main) - - print(circuit) - - -def test_kron(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - x = squin.op.x() - xx = squin.op.kron(x, x) - squin.qubit.apply(xx, q) - - circuit = emit_circuit(main) - - print(circuit) - - -def test_mult(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - x = squin.op.x() - y = squin.op.y() - m = squin.op.mult(x, y) - squin.qubit.apply(m, q[0]) - - circuit = emit_circuit(main) - - print(circuit) - - -def test_projector(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - h = squin.op.h() - squin.qubit.broadcast(h, q) - p0 = squin.op.p0() - p1 = squin.op.p1() - squin.qubit.apply(p0, q[0]) - squin.qubit.apply(p1, q[1]) - squin.qubit.measure(q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_sp_sn(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - sp = squin.op.spin_p() - sn = squin.op.spin_n() - squin.qubit.apply(sp, q) - squin.qubit.apply(sn, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_adjoint(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - s = squin.op.s() - s_dagger = squin.op.adjoint(s) - squin.qubit.apply(s, q) - squin.qubit.apply(s_dagger, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_u3(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - u3 = squin.op.u(0.323, 1.123, math.pi / 7) - squin.qubit.apply(u3, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_scale(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - x = squin.op.x() - s = 2 * x - squin.qubit.apply(s, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_phase(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - p = squin.op.phase(math.pi / 3) - squin.qubit.apply(p, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_shift(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - p = squin.op.shift(math.pi / 7) - squin.qubit.apply(p, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_reset(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - r = squin.op.reset() - squin.qubit.apply(r, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_pauli_string(): - @squin.kernel - def main(): - p = squin.op.pauli_string(string="XYZ") - q = squin.qubit.new(3) - squin.qubit.apply(p, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_invoke_cache(): - @squin.kernel - def sub_kernel(q_: squin.qubit.Qubit): - squin.qubit.apply(squin.op.h(), q_) - - @squin.kernel - def main(): - q = squin.qubit.new(2) - q0 = q[0] - sub_kernel(q0) - sub_kernel(q[1]) - sub_kernel(q0) - - target = emit.base.EmitCirq(main.dialects) - - circuit = target.run(main, ()) - - print(circuit) - - assert len(target._cached_circuit_operations) == 2 - - -def test_rot(): - @squin.kernel - def main(): - axis = squin.op.x() - q = squin.qubit.new(1) - r = squin.op.rot(axis=axis, angle=math.pi / 2) - squin.qubit.apply(r, q[0]) - - circuit = emit_circuit(main) - - print(circuit) - - assert circuit[0].operations[0].gate == cirq.Rx(rads=math.pi / 2) - - @squin.kernel - def main2(): - x = squin.op.x() - y = squin.op.y() - q = squin.qubit.new(1) - r = squin.op.rot(axis=x * y, angle=0.123) - squin.qubit.apply(r, q[0]) - - with pytest.raises(EmitError): - emit_circuit(main2) - - @squin.kernel - def main3(): - op = squin.op.h() - q = squin.qubit.new(1) - r = squin.op.rot(axis=op, angle=0.123) - squin.qubit.apply(r, q[0]) - - with pytest.raises(EmitError): - emit_circuit(main3) - - -def test_additional_stmts(): - @squin.kernel - def main(): - x = squin.op.x() - r = squin.op.rot(x, 0.123) - q = squin.qubit.new(3) - squin.qubit.apply(r, q[0]) - sqrt_x = squin.op.sqrt_x() - sqrt_y = squin.op.sqrt_y() - squin.qubit.apply(sqrt_x, q[1]) - squin.qubit.apply(sqrt_y, q[2]) - - main.print() - - circuit = emit_circuit(main) - - print(circuit) - - q = cirq.LineQubit.range(3) - expected_circuit = cirq.Circuit( - cirq.Rx(rads=0.123).on(q[0]), - cirq.X(q[1]) ** 0.5, - cirq.Y(q[2]) ** 0.5, - ) - - assert circuit == expected_circuit - - -def test_return_measurement(): - - @squin.kernel - def coinflip(): - qubit = squin.qubit.new(1)[0] - squin.gate.h(qubit) - return squin.qubit.measure(qubit) - - coinflip.print() - - circuit = emit_circuit(coinflip, ignore_returns=True) - print(circuit) - - -def test_reset_to_one(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - squin.gate.h(q[0]) - squin.gate.reset_to_one(q[0]) - - inline.InlinePass(main.dialects)(main) - - main.print() - - circuit = emit_circuit(main) - - q = cirq.LineQubit(0) - expected_circuit = cirq.Circuit( - cirq.H(q), - cirq.reset(q), - cirq.X(q), - ) - - print(circuit) - - assert circuit == expected_circuit - - -def test_overlapping_operations(): - @squin.kernel - def main(): - q = squin.qubit.new(5) - - x = squin.op.x() - y = squin.op.y() - op = x * y - - squin.qubit.broadcast(op, q) - - circuit = emit_circuit(main) - - print(circuit) From f931e476f9eb6e511be53550bc01f49a163f9c72 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Thu, 2 Oct 2025 13:55:32 +0200 Subject: [PATCH 09/11] Re-enable cirq util tests --- test/cirq_utils/noise/{noise_models.py => test_noise_models.py} | 0 test/cirq_utils/noise/{noisy_ghz.py => test_noisy_ghz.py} | 0 ...zone_correlated_noise.py => test_one_zone_correlated_noise.py} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename test/cirq_utils/noise/{noise_models.py => test_noise_models.py} (100%) rename test/cirq_utils/noise/{noisy_ghz.py => test_noisy_ghz.py} (100%) rename test/cirq_utils/noise/{one_zone_correlated_noise.py => test_one_zone_correlated_noise.py} (100%) diff --git a/test/cirq_utils/noise/noise_models.py b/test/cirq_utils/noise/test_noise_models.py similarity index 100% rename from test/cirq_utils/noise/noise_models.py rename to test/cirq_utils/noise/test_noise_models.py diff --git a/test/cirq_utils/noise/noisy_ghz.py b/test/cirq_utils/noise/test_noisy_ghz.py similarity index 100% rename from test/cirq_utils/noise/noisy_ghz.py rename to test/cirq_utils/noise/test_noisy_ghz.py diff --git a/test/cirq_utils/noise/one_zone_correlated_noise.py b/test/cirq_utils/noise/test_one_zone_correlated_noise.py similarity index 100% rename from test/cirq_utils/noise/one_zone_correlated_noise.py rename to test/cirq_utils/noise/test_one_zone_correlated_noise.py From c6bd3f911e05e0467e5d028c781198573181e09a Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Thu, 2 Oct 2025 14:10:23 +0200 Subject: [PATCH 10/11] Rename emit clifford class --- src/bloqade/cirq_utils/emit/clifford.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bloqade/cirq_utils/emit/clifford.py b/src/bloqade/cirq_utils/emit/clifford.py index eb97a632..086ec5d3 100644 --- a/src/bloqade/cirq_utils/emit/clifford.py +++ b/src/bloqade/cirq_utils/emit/clifford.py @@ -9,7 +9,7 @@ @clifford.dialect.register(key="emit.cirq") -class EmitCirqOpMethods(MethodTable): +class EmitCirqCliffordMethods(MethodTable): @impl(clifford.stmts.X) @impl(clifford.stmts.Y) From 6ba3ccf8ac4cad082ed36edfb910dc1099ac1b43 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 7 Oct 2025 10:28:42 +0200 Subject: [PATCH 11/11] Update docstrings and mark method tables private --- src/bloqade/cirq_utils/emit/base.py | 26 ++++++++++--------------- src/bloqade/cirq_utils/emit/clifford.py | 2 +- src/bloqade/cirq_utils/emit/noise.py | 2 +- src/bloqade/cirq_utils/lowering.py | 12 ++++++++---- 4 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/bloqade/cirq_utils/emit/base.py b/src/bloqade/cirq_utils/emit/base.py index 326ebd55..e481cb4e 100644 --- a/src/bloqade/cirq_utils/emit/base.py +++ b/src/bloqade/cirq_utils/emit/base.py @@ -43,16 +43,15 @@ def emit_circuit( ```python from bloqade import squin + from bloqade.cirq_utils import emit_circuit @squin.kernel def main(): q = squin.qubit.new(2) - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - squin.qubit.apply(cx, q) + squin.h(q[0]) + squin.cx(q[0], q[1]) - circuit = squin.cirq.emit_circuit(main) + circuit = emit_circuit(main) print(circuit) ``` @@ -62,30 +61,25 @@ def main(): ```python from bloqade import squin + from bloqade.cirq_utils import emit_circuit from kirin.dialects import ilist from typing import Literal import cirq @squin.kernel def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]): - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - squin.qubit.apply(cx, q) - return cx + squin.h(q[0]) + squin.cx(q[0], q[1]) @squin.kernel def main(): q = squin.qubit.new(2) - cx = entangle(q) - q2 = squin.qubit.new(3) - squin.qubit.apply(cx, [q[1], q2[2]]) - + entangle(q) # custom list of qubits on grid qubits = [cirq.GridQubit(i, i+1) for i in range(5)] - circuit = squin.cirq.emit_circuit(main, circuit_qubits=qubits) + circuit = emit_circuit(main, circuit_qubits=qubits) print(circuit) ``` @@ -184,7 +178,7 @@ def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit: @func.dialect.register(key="emit.cirq") -class FuncEmit(MethodTable): +class __FuncEmit(MethodTable): @impl(func.Function) def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function): diff --git a/src/bloqade/cirq_utils/emit/clifford.py b/src/bloqade/cirq_utils/emit/clifford.py index 086ec5d3..8e8d01f7 100644 --- a/src/bloqade/cirq_utils/emit/clifford.py +++ b/src/bloqade/cirq_utils/emit/clifford.py @@ -9,7 +9,7 @@ @clifford.dialect.register(key="emit.cirq") -class EmitCirqCliffordMethods(MethodTable): +class __EmitCirqCliffordMethods(MethodTable): @impl(clifford.stmts.X) @impl(clifford.stmts.Y) diff --git a/src/bloqade/cirq_utils/emit/noise.py b/src/bloqade/cirq_utils/emit/noise.py index a5721854..70476c93 100644 --- a/src/bloqade/cirq_utils/emit/noise.py +++ b/src/bloqade/cirq_utils/emit/noise.py @@ -7,7 +7,7 @@ @noise.dialect.register(key="emit.cirq") -class EmitCirqNoiseMethods(MethodTable): +class __EmitCirqNoiseMethods(MethodTable): two_qubit_paulis = ( "IX", diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 276f334c..39356bb9 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -51,7 +51,7 @@ def load_circuit( ```python # from cirq's "hello qubit" example import cirq - from bloqade import squin + from bloqade.cirq_utils import load_circuit # Pick a qubit. qubit = cirq.GridQubit(0, 0) @@ -63,7 +63,7 @@ def load_circuit( ) # load the circuit as squin - main = squin.load_circuit(circuit) + main = load_circuit(circuit) # print the resulting IR main.print() @@ -73,15 +73,19 @@ def load_circuit( and / or returning the respective quantum registers: ```python + import cirq + from bloqade.cirq_utils import load_circuit + from bloqade import squin + q = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q)) - get_entangled_qubits = squin.cirq.load_circuit( + get_entangled_qubits = load_circuit( circuit, return_register=True, kernel_name="get_entangled_qubits" ) get_entangled_qubits.print() - entangle_qubits = squin.cirq.load_circuit( + entangle_qubits = load_circuit( circuit, register_as_argument=True, kernel_name="entangle_qubits" )