diff --git a/src/bloqade/cirq_utils/emit/__init__.py b/src/bloqade/cirq_utils/emit/__init__.py index 098d868b..98caadcb 100644 --- a/src/bloqade/cirq_utils/emit/__init__.py +++ b/src/bloqade/cirq_utils/emit/__init__.py @@ -1,3 +1,3 @@ # NOTE: just to register methods -from . import op as op, noise as noise, qubit as qubit +from . import op as op, noise as noise, qubit as qubit, clifford as clifford from .base import emit_circuit as emit_circuit diff --git a/src/bloqade/cirq_utils/emit/base.py b/src/bloqade/cirq_utils/emit/base.py index 7adfc829..e481cb4e 100644 --- a/src/bloqade/cirq_utils/emit/base.py +++ b/src/bloqade/cirq_utils/emit/base.py @@ -43,16 +43,15 @@ def emit_circuit( ```python from bloqade import squin + from bloqade.cirq_utils import emit_circuit @squin.kernel def main(): q = squin.qubit.new(2) - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - squin.qubit.apply(cx, q) + squin.h(q[0]) + squin.cx(q[0], q[1]) - circuit = squin.cirq.emit_circuit(main) + circuit = emit_circuit(main) print(circuit) ``` @@ -62,30 +61,25 @@ def main(): ```python from bloqade import squin + from bloqade.cirq_utils import emit_circuit from kirin.dialects import ilist from typing import Literal import cirq @squin.kernel def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]): - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - squin.qubit.apply(cx, q) - return cx + squin.h(q[0]) + squin.cx(q[0], q[1]) @squin.kernel def main(): q = squin.qubit.new(2) - cx = entangle(q) - q2 = squin.qubit.new(3) - squin.qubit.apply(cx, [q[1], q2[2]]) - + entangle(q) # custom list of qubits on grid qubits = [cirq.GridQubit(i, i+1) for i in range(5)] - circuit = squin.cirq.emit_circuit(main, circuit_qubits=qubits) + circuit = emit_circuit(main, circuit_qubits=qubits) print(circuit) ``` @@ -115,7 +109,7 @@ def main(): f"The method from which you're trying to emit a circuit takes {len(mt.args)} as input, but you passed in {len(args)} via the `args` keyword!" ) - emitter = EmitCirq(qubits=qubits) + emitter = EmitCirq(qubits=circuit_qubits) return emitter.run(mt, args=args) @@ -137,7 +131,7 @@ class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]): dialects: ir.DialectGroup = field(default_factory=_default_kernel) void = cirq.Circuit() qubits: Sequence[cirq.Qid] | None = None - _cached_circuit_operations: dict[int, cirq.CircuitOperation] = field( + _cached_invokes: dict[int, cirq.FrozenCircuit] = field( init=False, default_factory=dict ) @@ -184,7 +178,7 @@ def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit: @func.dialect.register(key="emit.cirq") -class FuncEmit(MethodTable): +class __FuncEmit(MethodTable): @impl(func.Function) def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function): @@ -193,12 +187,22 @@ def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function): @impl(func.Invoke) def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke): - stmt_hash = hash((stmt.callee, stmt.inputs)) - if ( - cached_circuit_op := emit._cached_circuit_operations.get(stmt_hash) - ) is not None: + try: + stmt_hash = hash( + (stmt.callee, tuple(frame.get(input) for input in stmt.inputs)) + ) + except (TypeError, interp.InterpreterError): + # NOTE: avoid unhashable types and missing keys, just don't cache them + stmt_hash = None + + if stmt_hash is not None: + cached_circuit = emit._cached_invokes.get(stmt_hash) + else: + cached_circuit = None + + if cached_circuit is not None: # NOTE: cache hit - frame.circuit.append(cached_circuit_op) + frame.circuit.append(cached_circuit.all_operations()) return () ret = stmt.result @@ -230,9 +234,8 @@ def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke): if return_stmt is not None: frame.entries[ret] = sub_frame.get(return_stmt.value) - circuit_op = cirq.CircuitOperation( - sub_circuit.freeze(), use_repetition_ids=False - ) - emit._cached_circuit_operations[stmt_hash] = circuit_op - frame.circuit.append(circuit_op) + if stmt_hash is not None: + emit._cached_invokes[stmt_hash] = sub_circuit.freeze() + + frame.circuit.append(sub_circuit.all_operations()) return () diff --git a/src/bloqade/cirq_utils/emit/clifford.py b/src/bloqade/cirq_utils/emit/clifford.py new file mode 100644 index 00000000..8e8d01f7 --- /dev/null +++ b/src/bloqade/cirq_utils/emit/clifford.py @@ -0,0 +1,90 @@ +import math + +import cirq +from kirin.interp import MethodTable, impl + +from bloqade.squin import clifford + +from .base import EmitCirq, EmitCirqFrame + + +@clifford.dialect.register(key="emit.cirq") +class __EmitCirqCliffordMethods(MethodTable): + + @impl(clifford.stmts.X) + @impl(clifford.stmts.Y) + @impl(clifford.stmts.Z) + @impl(clifford.stmts.H) + def hermitian( + self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.SingleQubitGate + ): + qubits = frame.get(stmt.qubits) + cirq_op = getattr(cirq, stmt.name.upper()) + frame.circuit.append(cirq_op.on_each(qubits)) + return () + + @impl(clifford.stmts.S) + @impl(clifford.stmts.T) + def unitary( + self, + emit: EmitCirq, + frame: EmitCirqFrame, + stmt: clifford.stmts.SingleQubitNonHermitianGate, + ): + qubits = frame.get(stmt.qubits) + cirq_op = getattr(cirq, stmt.name.upper()) + if stmt.adjoint: + cirq_op = cirq_op ** (-1) + + frame.circuit.append(cirq_op.on_each(qubits)) + return () + + @impl(clifford.stmts.SqrtX) + @impl(clifford.stmts.SqrtY) + def sqrt( + self, + emit: EmitCirq, + frame: EmitCirqFrame, + stmt: clifford.stmts.SqrtX | clifford.stmts.SqrtY, + ): + qubits = frame.get(stmt.qubits) + + exponent = 0.5 + if stmt.adjoint: + exponent *= -1 + + if isinstance(stmt, clifford.stmts.SqrtX): + cirq_op = cirq.XPowGate(exponent=exponent) + else: + cirq_op = cirq.YPowGate(exponent=exponent) + + frame.circuit.append(cirq_op.on_each(qubits)) + return () + + @impl(clifford.stmts.CX) + @impl(clifford.stmts.CY) + @impl(clifford.stmts.CZ) + def control( + self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.ControlledGate + ): + controls = frame.get(stmt.controls) + targets = frame.get(stmt.targets) + cirq_op = getattr(cirq, stmt.name.upper()) + cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)] + frame.circuit.append(cirq_op.on_each(cirq_qubits)) + return () + + @impl(clifford.stmts.Rx) + @impl(clifford.stmts.Ry) + @impl(clifford.stmts.Rz) + def rot( + self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.RotationGate + ): + qubits = frame.get(stmt.qubits) + + turns = frame.get(stmt.angle) + angle = turns * 2 * math.pi + cirq_op = getattr(cirq, stmt.name.title())(rads=angle) + + frame.circuit.append(cirq_op.on_each(qubits)) + return () diff --git a/src/bloqade/cirq_utils/emit/noise.py b/src/bloqade/cirq_utils/emit/noise.py index bd8f8369..70476c93 100644 --- a/src/bloqade/cirq_utils/emit/noise.py +++ b/src/bloqade/cirq_utils/emit/noise.py @@ -1,36 +1,53 @@ import cirq -from kirin.emit import EmitError from kirin.interp import MethodTable, impl from bloqade.squin import noise from .base import EmitCirq, EmitCirqFrame -from .runtime import ( - KronRuntime, - BasicOpRuntime, - OperatorRuntimeABC, - PauliStringRuntime, -) @noise.dialect.register(key="emit.cirq") -class EmitCirqNoiseMethods(MethodTable): +class __EmitCirqNoiseMethods(MethodTable): + + two_qubit_paulis = ( + "IX", + "IY", + "IZ", + "XI", + "XX", + "XY", + "XZ", + "YI", + "YX", + "YY", + "YZ", + "ZI", + "ZX", + "ZY", + "ZZ", + ) @impl(noise.stmts.Depolarize) def depolarize( self, interp: EmitCirq, frame: EmitCirqFrame, stmt: noise.stmts.Depolarize ): p = frame.get(stmt.p) - gate = cirq.depolarize(p, n_qubits=1) - return (BasicOpRuntime(gate=gate),) + qubits = frame.get(stmt.qubits) + cirfq_op = cirq.depolarize(p, n_qubits=1).on_each(qubits) + frame.circuit.append(cirfq_op) + return () @impl(noise.stmts.Depolarize2) def depolarize2( self, interp: EmitCirq, frame: EmitCirqFrame, stmt: noise.stmts.Depolarize2 ): p = frame.get(stmt.p) - gate = cirq.depolarize(p, n_qubits=2) - return (BasicOpRuntime(gate=gate),) + controls = frame.get(stmt.controls) + targets = frame.get(stmt.targets) + cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)] + cirq_op = cirq.depolarize(p, n_qubits=2).on_each(cirq_qubits) + frame.circuit.append(cirq_op) + return () @impl(noise.stmts.SingleQubitPauliChannel) def single_qubit_pauli_channel( @@ -39,9 +56,15 @@ def single_qubit_pauli_channel( frame: EmitCirqFrame, stmt: noise.stmts.SingleQubitPauliChannel, ): - ps = frame.get(stmt.params) - gate = cirq.asymmetric_depolarize(*ps) - return (BasicOpRuntime(gate=gate),) + px = frame.get(stmt.px) + py = frame.get(stmt.py) + pz = frame.get(stmt.pz) + qubits = frame.get(stmt.qubits) + + cirq_op = cirq.asymmetric_depolarize(px, py, pz).on_each(qubits) + frame.circuit.append(cirq_op) + + return () @impl(noise.stmts.TwoQubitPauliChannel) def two_qubit_pauli_channel( @@ -50,33 +73,18 @@ def two_qubit_pauli_channel( frame: EmitCirqFrame, stmt: noise.stmts.TwoQubitPauliChannel, ): - ps = frame.get(stmt.params) - paulis = ("I", "X", "Y", "Z") - pauli_combinations = [ - pauli1 + pauli2 - for pauli1 in paulis - for pauli2 in paulis - if not (pauli1 == pauli2 == "I") - ] - error_probabilities = {key: p for (key, p) in zip(pauli_combinations, ps)} - gate = cirq.asymmetric_depolarize(error_probabilities=error_probabilities) - return (BasicOpRuntime(gate),) - - @staticmethod - def _op_to_key(operator: OperatorRuntimeABC) -> str: - match operator: - case KronRuntime(): - key_lhs = EmitCirqNoiseMethods._op_to_key(operator.lhs) - key_rhs = EmitCirqNoiseMethods._op_to_key(operator.rhs) - return key_lhs + key_rhs + ps = frame.get(stmt.probabilities) + error_probabilities = { + key: p for (key, p) in zip(self.two_qubit_paulis, ps) if p != 0 + } - case BasicOpRuntime(): - return str(operator.gate) + controls = frame.get(stmt.controls) + targets = frame.get(stmt.targets) + cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)] - case PauliStringRuntime(): - return operator.string + cirq_op = cirq.asymmetric_depolarize( + error_probabilities=error_probabilities + ).on_each(cirq_qubits) + frame.circuit.append(cirq_op) - case _: - raise EmitError( - f"Unexpected operator runtime in StochasticUnitaryChannel of type {type(operator).__name__} encountered!" - ) + return () diff --git a/src/bloqade/cirq_utils/emit/qubit.py b/src/bloqade/cirq_utils/emit/qubit.py index 080921e7..47d736ab 100644 --- a/src/bloqade/cirq_utils/emit/qubit.py +++ b/src/bloqade/cirq_utils/emit/qubit.py @@ -14,11 +14,13 @@ def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New): n_qubits = frame.get(stmt.n_qubits) if frame.qubits is not None: - cirq_qubits = [frame.qubits[i + frame.qubit_index] for i in range(n_qubits)] + cirq_qubits = tuple( + frame.qubits[i + frame.qubit_index] for i in range(n_qubits) + ) else: - cirq_qubits = [ + cirq_qubits = tuple( cirq.LineQubit(i + frame.qubit_index) for i in range(n_qubits) - ] + ) frame.qubit_index += n_qubits return (cirq_qubits,) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 76a706f5..39356bb9 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -1,4 +1,3 @@ -import math from typing import Any from dataclasses import field, dataclass @@ -7,7 +6,7 @@ from kirin.rewrite import Walk, CFGCompactify from kirin.dialects import py, scf, func, ilist -from bloqade.squin import op, noise, qubit, kernel +from bloqade.squin import noise, qubit, kernel, clifford def load_circuit( @@ -52,7 +51,7 @@ def load_circuit( ```python # from cirq's "hello qubit" example import cirq - from bloqade import squin + from bloqade.cirq_utils import load_circuit # Pick a qubit. qubit = cirq.GridQubit(0, 0) @@ -64,7 +63,7 @@ def load_circuit( ) # load the circuit as squin - main = squin.load_circuit(circuit) + main = load_circuit(circuit) # print the resulting IR main.print() @@ -74,15 +73,19 @@ def load_circuit( and / or returning the respective quantum registers: ```python + import cirq + from bloqade.cirq_utils import load_circuit + from bloqade import squin + q = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q)) - get_entangled_qubits = squin.cirq.load_circuit( + get_entangled_qubits = load_circuit( circuit, return_register=True, kernel_name="get_entangled_qubits" ) get_entangled_qubits.print() - entangle_qubits = squin.cirq.load_circuit( + entangle_qubits = load_circuit( circuit, register_as_argument=True, kernel_name="entangle_qubits" ) @@ -149,7 +152,14 @@ def main(): ) -CirqNode = cirq.Circuit | cirq.Moment | cirq.Gate | cirq.Qid | cirq.Operation +CirqNode = ( + cirq.Circuit + | cirq.FrozenCircuit + | cirq.Moment + | cirq.Gate + | cirq.Qid + | cirq.Operation +) DecomposeNode = ( cirq.SwapPowGate @@ -157,11 +167,16 @@ def main(): | cirq.PhasedXPowGate | cirq.PhasedXZGate | cirq.CSwapGate + | cirq.XXPowGate + | cirq.YYPowGate + | cirq.ZZPowGate + | cirq.CCXPowGate + | cirq.CCZPowGate ) @dataclass -class Squin(lowering.LoweringABC[CirqNode]): +class Squin(lowering.LoweringABC[cirq.Circuit]): """Lower a cirq.Circuit object to a squin kernel""" circuit: cirq.Circuit @@ -169,26 +184,45 @@ class Squin(lowering.LoweringABC[CirqNode]): qreg_index: dict[cirq.Qid, int] = field(init=False, default_factory=dict) next_qreg_index: int = field(init=False, default=0) + two_qubit_paulis = ( + "IX", + "IY", + "IZ", + "XI", + "XX", + "XY", + "XZ", + "YI", + "YX", + "YY", + "YZ", + "ZI", + "ZX", + "ZY", + "ZZ", + ) + def __post_init__(self): # TODO: sort by cirq ordering qbits = sorted(self.circuit.all_qubits()) self.qreg_index = {qid: idx for (idx, qid) in enumerate(qbits)} - def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid): + def lower_qubit_getindex(self, state: lowering.State[cirq.Circuit], qid: cirq.Qid): index = self.qreg_index[qid] index_ssa = state.current_frame.push(py.Constant(index)).result qbit_getitem = state.current_frame.push(py.GetItem(self.qreg, index_ssa)) return qbit_getitem.result def lower_qubit_getindices( - self, state: lowering.State[CirqNode], qids: list[cirq.Qid] + self, state: lowering.State[cirq.Circuit], qids: tuple[cirq.Qid, ...] ): qbits_getitem = [self.lower_qubit_getindex(state, qid) for qid in qids] - return tuple(qbits_getitem) + qbits = state.current_frame.push(ilist.New(values=qbits_getitem)) + return qbits.result def run( self, - stmt: CirqNode, + stmt: cirq.Circuit, *, source: str | None = None, globals: dict[str, Any] | None = None, @@ -232,83 +266,66 @@ def run( return region - def visit(self, state: lowering.State[CirqNode], node: CirqNode) -> lowering.Result: + def visit( + self, state: lowering.State[cirq.Circuit], node: CirqNode + ) -> lowering.Result: name = node.__class__.__name__ return getattr(self, f"visit_{name}", self.generic_visit)(state, node) - def generic_visit(self, state: lowering.State[CirqNode], node: CirqNode): + def generic_visit(self, state: lowering.State[cirq.Circuit], node: CirqNode): if isinstance(node, CirqNode): raise lowering.BuildError( f"Cannot lower {node.__class__.__name__} node: {node}" ) - raise lowering.BuildError( - f"Unexpected `{node.__class__.__name__}` node: {repr(node)} is not an AST node" - ) + raise lowering.BuildError(f"Cannot lower {node}") - def lower_literal(self, state: lowering.State[CirqNode], value) -> ir.SSAValue: + # return self.visit_Operation(state, node) + + def lower_literal(self, state: lowering.State[cirq.Circuit], value) -> ir.SSAValue: raise lowering.BuildError("Literals not supported in cirq circuit") def lower_global( - self, state: lowering.State[CirqNode], node: CirqNode + self, state: lowering.State[cirq.Circuit], node: CirqNode ) -> lowering.LoweringABC.Result: raise lowering.BuildError("Literals not supported in cirq circuit") def visit_Circuit( - self, state: lowering.State[CirqNode], node: cirq.Circuit + self, + state: lowering.State[cirq.Circuit], + node: cirq.Circuit | cirq.FrozenCircuit, ) -> lowering.Result: for moment in node: - state.lower(moment) + self.visit_Moment(state, moment) def visit_Moment( - self, state: lowering.State[CirqNode], node: cirq.Moment + self, state: lowering.State[cirq.Circuit], node: cirq.Moment ) -> lowering.Result: for op_ in node.operations: - state.lower(op_) + self.visit(state, op_) def visit_GateOperation( - self, state: lowering.State[CirqNode], node: cirq.GateOperation + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation ): - if isinstance(node.gate, cirq.MeasurementGate): - # NOTE: special dispatch here, since measurement is a gate + a qubit in cirq, - # but a single statement in squin - return self.lower_measurement(state, node) - if isinstance(node.gate, DecomposeNode): # NOTE: easier to decompose these, but for that we need the qubits too, # so we need to do this within this method for subnode in cirq.decompose_once(node): - state.lower(subnode) + self.visit(state, subnode) return - op_ = state.lower(node.gate).expect_one() - qbits = self.lower_qubit_getindices(state, node.qubits) - return state.current_frame.push(qubit.Apply(operator=op_, qubits=qbits)) + # NOTE: just forward to the appropriate method by getting the name + name = node.gate.__class__.__name__ + return getattr(self, f"visit_{name}", self.generic_visit)(state, node) def visit_TaggedOperation( - self, state: lowering.State[CirqNode], node: cirq.TaggedOperation + self, state: lowering.State[cirq.Circuit], node: cirq.TaggedOperation ): - state.lower(node.untagged) - - def lower_measurement( - self, state: lowering.State[CirqNode], node: cirq.GateOperation - ): - if len(node.qubits) == 1: - qbit = self.lower_qubit_getindex(state, node.qubits[0]) - stmt = state.current_frame.push(qubit.MeasureQubit(qbit)) - else: - qbits = self.lower_qubit_getindices(state, node.qubits) - qbits_list = state.current_frame.push(ilist.New(values=qbits)) - stmt = state.current_frame.push(qubit.MeasureQubitList(qbits_list.result)) - - key = node.gate.key - if isinstance(key, cirq.MeasurementKey): - key = key.name - - state.current_frame.defs[key] = stmt.result - return stmt + return self.visit(state, node.untagged) def visit_ClassicallyControlledOperation( - self, state: lowering.State[CirqNode], node: cirq.ClassicallyControlledOperation + self, + state: lowering.State[cirq.Circuit], + node: cirq.ClassicallyControlledOperation, ): conditions: list[ir.SSAValue] = [] for outcome in node.classical_controls: @@ -369,211 +386,252 @@ def bool_op_or(x: bool, y: bool) -> bool: return state.current_frame.push(scf.IfElse(condition, then_body=then_body)) + def visit_MeasurementGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation + ): + cirq_qubits = node.qubits + if len(cirq_qubits) == 1: + qbit = self.lower_qubit_getindex(state, node.qubits[0]) + stmt = state.current_frame.push(qubit.MeasureQubit(qbit)) + else: + qubits = self.lower_qubit_getindices(state, node.qubits) + stmt = state.current_frame.push(qubit.MeasureQubitList(qubits)) + + # NOTE: add for classically controlled lowering + key = node.gate.key + if isinstance(key, cirq.MeasurementKey): + key = key.name + state.current_frame.defs[key] = stmt.result + + return stmt + def visit_SingleQubitPauliStringGateOperation( self, - state: lowering.State[CirqNode], + state: lowering.State[cirq.Circuit], node: cirq.SingleQubitPauliStringGateOperation, ): + if isinstance(node.pauli, cirq.IdentityGate): + # TODO: do we need an identity gate in clifford? + return + qargs = self.lower_qubit_getindices(state, (node.qubit,)) match node.pauli: case cirq.X: - op_ = op.stmts.X() + clifford_stmt = clifford.stmts.X case cirq.Y: - op_ = op.stmts.Y() + clifford_stmt = clifford.stmts.Y case cirq.Z: - op_ = op.stmts.Z() - case cirq.I: - op_ = op.stmts.Identity(sites=1) + clifford_stmt = clifford.stmts.Z case _: raise lowering.BuildError(f"Unexpected Pauli operation {node.pauli}") - state.current_frame.push(op_) - qargs = self.lower_qubit_getindices(state, [node.qubit]) - return state.current_frame.push(qubit.Apply(op_.result, qargs)) - - def visit_HPowGate(self, state: lowering.State[CirqNode], node: cirq.HPowGate): - if abs(node.exponent) == 1: - return state.current_frame.push(op.stmts.H()) - - # NOTE: decompose into products of paulis for arbitrary exponents according to _decompose_ method - # can't use decompose directly since that method requires qubits to be passed in for some reason - y_rhs = state.lower(cirq.YPowGate(exponent=0.25)).expect_one() - x = state.lower( - cirq.XPowGate(exponent=node.exponent, global_shift=node.global_shift) - ).expect_one() - y_lhs = state.lower(cirq.YPowGate(exponent=-0.25)).expect_one() + return state.current_frame.push(clifford_stmt(qargs)) - # NOTE: reversed order since we're creating a mult stmt - m_lhs = state.current_frame.push(op.stmts.Mult(y_lhs, x)) - return state.current_frame.push(op.stmts.Mult(m_lhs.result, y_rhs)) + def visit_HPowGate(self, state: lowering.State[cirq.Circuit], node: cirq.HPowGate): + qargs = self.lower_qubit_getindices(state, node.qubits) - def visit_XPowGate(self, state: lowering.State[CirqNode], node: cirq.XPowGate): - if abs(node.exponent == 1): - return state.current_frame.push(op.stmts.X()) + if node.gate.exponent % 2 == 1: + return state.current_frame.push(clifford.stmts.H(qargs)) - return self.visit(state, node.in_su2()) + # NOTE: decompose into products of paulis for arbitrary exponents according to _decompose_ method + for subnode in cirq.decompose_once(node): + self.visit(state, subnode) - def visit_YPowGate(self, state: lowering.State[CirqNode], node: cirq.YPowGate): - if abs(node.exponent == 1): - return state.current_frame.push(op.stmts.Y()) + def visit_XPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation + ): + qargs = self.lower_qubit_getindices(state, node.qubits) + if node.gate.exponent % 2 == 1: + return state.current_frame.push(clifford.stmts.X(qargs)) - return self.visit(state, node.in_su2()) + angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Rx(angle.result, qargs)) - def visit_ZPowGate(self, state: lowering.State[CirqNode], node: cirq.ZPowGate): - if node.exponent == 0.5: - return state.current_frame.push(op.stmts.S()) + def visit_YPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation + ): + qargs = self.lower_qubit_getindices(state, node.qubits) + if node.gate.exponent % 2 == 1: + return state.current_frame.push(clifford.stmts.Y(qargs)) - if node.exponent == 0.25: - return state.current_frame.push(op.stmts.T()) + angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Ry(angle.result, qargs)) - if abs(node.exponent == 1): - return state.current_frame.push(op.stmts.Z()) + def visit_ZPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation + ): + qargs = self.lower_qubit_getindices(state, node.qubits) - # NOTE: just for the Z gate, an arbitrary exponent is equivalent to the ShiftOp - # up to a minus sign! - t = -node.exponent - theta = state.current_frame.push(py.Constant(math.pi * t)) - return state.current_frame.push(op.stmts.ShiftOp(theta=theta.result)) + if abs(node.gate.exponent) == 0.5: + adjoint = node.gate.exponent < 0 + return state.current_frame.push( + clifford.stmts.S(adjoint=adjoint, qubits=qargs) + ) - def visit_Rx(self, state: lowering.State[CirqNode], node: cirq.Rx): - x = state.current_frame.push(op.stmts.X()) - angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent)) - return state.current_frame.push(op.stmts.Rot(axis=x.result, angle=angle.result)) + if abs(node.gate.exponent) == 0.25: + adjoint = node.gate.exponent < 0 + return state.current_frame.push( + clifford.stmts.T(adjoint=adjoint, qubits=qargs) + ) - def visit_Ry(self, state: lowering.State[CirqNode], node: cirq.Ry): - y = state.current_frame.push(op.stmts.Y()) - angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent)) - return state.current_frame.push(op.stmts.Rot(axis=y.result, angle=angle.result)) + if node.gate.exponent % 2 == 1: + return state.current_frame.push(clifford.stmts.Z(qubits=qargs)) - def visit_Rz(self, state: lowering.State[CirqNode], node: cirq.Rz): - z = state.current_frame.push(op.stmts.Z()) - angle = state.current_frame.push(py.Constant(value=math.pi * node.exponent)) - return state.current_frame.push(op.stmts.Rot(axis=z.result, angle=angle.result)) + angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Rz(angle.result, qargs)) - def visit_CXPowGate(self, state: lowering.State[CirqNode], node: cirq.CXPowGate): - x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Control(x, n_controls=1)) + def visit_Rx(self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation): + qargs = self.lower_qubit_getindices(state, node.qubits) + angle = state.current_frame.push(py.Constant(value=0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Rx(angle.result, qargs)) - def visit_CZPowGate(self, state: lowering.State[CirqNode], node: cirq.CZPowGate): - z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Control(z, n_controls=1)) + def visit_Ry(self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation): + qargs = self.lower_qubit_getindices(state, node.qubits) + angle = state.current_frame.push(py.Constant(value=0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Ry(angle.result, qargs)) - def visit_ControlledOperation( - self, state: lowering.State[CirqNode], node: cirq.ControlledOperation - ): - return self.visit_GateOperation(state, node) + def visit_Rz(self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation): + qargs = self.lower_qubit_getindices(state, node.qubits) + angle = state.current_frame.push(py.Constant(value=0.5 * node.gate.exponent)) + return state.current_frame.push(clifford.stmts.Rz(angle.result, qargs)) - def visit_ControlledGate( - self, state: lowering.State[CirqNode], node: cirq.ControlledGate + def visit_CXPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation ): - op_ = state.lower(node.sub_gate).expect_one() - n_controls = node.num_controls() - return state.current_frame.push(op.stmts.Control(op_, n_controls=n_controls)) - - def visit_XXPowGate(self, state: lowering.State[CirqNode], node: cirq.XXPowGate): - x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Kron(x, x)) + if node.gate.exponent % 2 == 0: + return - def visit_YYPowGate(self, state: lowering.State[CirqNode], node: cirq.YYPowGate): - y = state.lower(cirq.YPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Kron(y, y)) + if node.gate.exponent % 2 != 1: + raise lowering.BuildError("Exponents of CX gate are not supported!") - def visit_ZZPowGate(self, state: lowering.State[CirqNode], node: cirq.ZZPowGate): - z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Kron(z, z)) + control, target = node.qubits + control_qarg = self.lower_qubit_getindices(state, (control,)) + target_qarg = self.lower_qubit_getindices(state, (target,)) + return state.current_frame.push( + clifford.stmts.CX(controls=control_qarg, targets=target_qarg) + ) - def visit_CCXPowGate(self, state: lowering.State[CirqNode], node: cirq.CCXPowGate): - x = state.lower(cirq.XPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Control(x, n_controls=2)) + def visit_CZPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation + ): + if node.gate.exponent % 2 == 0: + return - def visit_CCZPowGate(self, state: lowering.State[CirqNode], node: cirq.CCZPowGate): - z = state.lower(cirq.ZPowGate(exponent=node.exponent)).expect_one() - return state.current_frame.push(op.stmts.Control(z, n_controls=2)) + if node.gate.exponent % 2 != 1: + raise lowering.BuildError("Exponents of CZ gate are not supported!") - def visit_BitFlipChannel( - self, state: lowering.State[CirqNode], node: cirq.BitFlipChannel - ): - x = state.current_frame.push(op.stmts.X()) - p = state.current_frame.push(py.Constant(node.p)) + control, target = node.qubits + control_qarg = self.lower_qubit_getindices(state, (control,)) + target_qarg = self.lower_qubit_getindices(state, (target,)) return state.current_frame.push( - noise.stmts.PauliError(basis=x.result, p=p.result) + clifford.stmts.CZ(controls=control_qarg, targets=target_qarg) ) - def visit_AmplitudeDampingChannel( - self, state: lowering.State[CirqNode], node: cirq.AmplitudeDampingChannel + def visit_ControlledOperation( + self, state: lowering.State[cirq.Circuit], node: cirq.ControlledOperation ): - r = state.current_frame.push(op.stmts.Reset()) - p = state.current_frame.push(py.Constant(node.gamma)) + match node.gate.sub_gate: + case cirq.X: + stmt = clifford.stmts.CX + case cirq.Y: + stmt = clifford.stmts.CY + case cirq.Z: + stmt = clifford.stmts.CZ + case _: + raise lowering.BuildError( + f"Cannot lowering controlled operation: {node}" + ) - # TODO: do we need a dedicated noise stmt for this? Using PauliError - # with this basis feels like a hack - noise_channel = state.current_frame.push( - noise.stmts.PauliError(basis=r.result, p=p.result) - ) + control, target = node.qubits + control_qarg = self.lower_qubit_getindices(state, (control,)) + target_qarg = self.lower_qubit_getindices(state, (target,)) + return state.current_frame.push(stmt(control_qarg, target_qarg)) - return noise_channel + def visit_FrozenCircuit( + self, state: lowering.State[cirq.Circuit], node: cirq.FrozenCircuit + ): + return self.visit_Circuit(state, node) - def visit_GeneralizedAmplitudeDampingChannel( - self, - state: lowering.State[CirqNode], - node: cirq.GeneralizedAmplitudeDampingChannel, + def visit_CircuitOperation( + self, state: lowering.State[cirq.Circuit], node: cirq.CircuitOperation ): - p = state.current_frame.push(py.Constant(node.p)).result - gamma = state.current_frame.push(py.Constant(node.gamma)).result + reps = node.repetitions - # NOTE: cirq has a weird convention here: if p == 1, we have AmplitudeDampingChannel, - # which basically means p is the probability of the environment being in the vacuum state - prob0 = state.current_frame.push(py.binop.Mult(p, gamma)).result - one_ = state.current_frame.push(py.Constant(1)).result - p_minus_1 = state.current_frame.push(py.binop.Sub(one_, p)).result - prob1 = state.current_frame.push(py.binop.Mult(p_minus_1, gamma)).result + if not isinstance(reps, int): + raise lowering.BuildError( + f"Cannot lower CircuitOperation with non-integer repetitions: {node}" + ) - r0 = state.current_frame.push(op.stmts.Reset()).result - r1 = state.current_frame.push(op.stmts.ResetToOne()).result + if reps > 1: + raise lowering.BuildError( + "Repetitions of circuit operatiosn not yet supported" + ) - probs = state.current_frame.push(ilist.New(values=(prob0, prob1))).result - ops = state.current_frame.push(ilist.New(values=(r0, r1))).result + return self.visit(state, node.circuit) - noise_channel = state.current_frame.push( - noise.stmts.StochasticUnitaryChannel(probabilities=probs, operators=ops) + def visit_BitFlipChannel( + self, state: lowering.State[cirq.Circuit], node: cirq.BitFlipChannel + ): + p = node.gate.p + p_x = state.current_frame.push(py.Constant(p)).result + p_y = p_z = state.current_frame.push(py.Constant(0)).result + qubits = self.lower_qubit_getindices(state, node.qubits) + return state.current_frame.push( + noise.stmts.SingleQubitPauliChannel(px=p_x, py=p_y, pz=p_z, qubits=qubits) ) - return noise_channel - def visit_DepolarizingChannel( - self, state: lowering.State[CirqNode], node: cirq.DepolarizingChannel + self, state: lowering.State[cirq.Circuit], node: cirq.DepolarizingChannel ): - p = state.current_frame.push(py.Constant(node.p)).result - return state.current_frame.push(noise.stmts.Depolarize(p)) + p = state.current_frame.push(py.Constant(node.gate.p)).result + qubits = self.lower_qubit_getindices(state, node.qubits) + return state.current_frame.push(noise.stmts.Depolarize(p, qubits=qubits)) def visit_AsymmetricDepolarizingChannel( - self, state: lowering.State[CirqNode], node: cirq.AsymmetricDepolarizingChannel + self, + state: lowering.State[cirq.Circuit], + node: cirq.AsymmetricDepolarizingChannel, ): - nqubits = node.num_qubits() + nqubits = node.gate.num_qubits() if nqubits > 2: raise lowering.BuildError( "AsymmetricDepolarizingChannel applied to more than 2 qubits is not supported!" ) if nqubits == 1: - p_x = state.current_frame.push(py.Constant(node.p_x)).result - p_y = state.current_frame.push(py.Constant(node.p_y)).result - p_z = state.current_frame.push(py.Constant(node.p_z)).result - params = state.current_frame.push(ilist.New(values=(p_x, p_y, p_z))).result - return state.current_frame.push(noise.stmts.SingleQubitPauliChannel(params)) + qubits = self.lower_qubit_getindices(state, node.qubits) + p_x = state.current_frame.push(py.Constant(node.gate.p_x)).result + p_y = state.current_frame.push(py.Constant(node.gate.p_y)).result + p_z = state.current_frame.push(py.Constant(node.gate.p_z)).result + return state.current_frame.push( + noise.stmts.SingleQubitPauliChannel(p_x, p_y, p_z, qubits) + ) # NOTE: nqubits == 2 - error_probs = node.error_probabilities - paulis = ("I", "X", "Y", "Z") - values = [] - for p1 in paulis: - for p2 in paulis: - if p1 == p2 == "I": - continue - - p = error_probs.get(p1 + p2, 0.0) + error_probs = node.gate.error_probabilities + probability_values = [] + p0 = None + for key in self.two_qubit_paulis: + p = error_probs.get(key) + + if p is None: + if p0 is None: + p0 = state.current_frame.push(py.Constant(0)).result + p_ssa = p0 + else: p_ssa = state.current_frame.push(py.Constant(p)).result - values.append(p_ssa) + probability_values.append(p_ssa) - params = state.current_frame.push(ilist.New(values=values)).result - return state.current_frame.push(noise.stmts.TwoQubitPauliChannel(params)) + probabilities = state.current_frame.push( + ilist.New(values=probability_values) + ).result + + control, target = node.qubits + control_qarg = self.lower_qubit_getindices(state, (control,)) + target_qarg = self.lower_qubit_getindices(state, (target,)) + + return state.current_frame.push( + noise.stmts.TwoQubitPauliChannel( + probabilities, controls=control_qarg, targets=target_qarg + ) + ) diff --git a/test/cirq_utils/noise/noise_models.py b/test/cirq_utils/noise/test_noise_models.py similarity index 100% rename from test/cirq_utils/noise/noise_models.py rename to test/cirq_utils/noise/test_noise_models.py diff --git a/test/cirq_utils/noise/noisy_ghz.py b/test/cirq_utils/noise/test_noisy_ghz.py similarity index 100% rename from test/cirq_utils/noise/noisy_ghz.py rename to test/cirq_utils/noise/test_noisy_ghz.py diff --git a/test/cirq_utils/noise/one_zone_correlated_noise.py b/test/cirq_utils/noise/test_one_zone_correlated_noise.py similarity index 100% rename from test/cirq_utils/noise/one_zone_correlated_noise.py rename to test/cirq_utils/noise/test_one_zone_correlated_noise.py diff --git a/test/cirq_utils/cirq_to_squin.py b/test/cirq_utils/test_cirq_to_squin.py similarity index 96% rename from test/cirq_utils/cirq_to_squin.py rename to test/cirq_utils/test_cirq_to_squin.py index 6da8fe0b..6e0fc626 100644 --- a/test/cirq_utils/cirq_to_squin.py +++ b/test/cirq_utils/test_cirq_to_squin.py @@ -44,7 +44,7 @@ def controlled_gates(): return cirq.Circuit( cirq.H(q1), cirq.X(q0).controlled_by(q1), - cirq.Rx(rads=math.pi / 4).on(q0).controlled_by(q1), + cirq.Y.on(q0).controlled_by(q1), cirq.measure(q0, q1), ) @@ -90,7 +90,7 @@ def two_qubit_pow_gates(): q1 = cirq.LineQubit(1) return cirq.Circuit( - cirq.CX(q0, q1) ** 2, cirq.CZ(q0, q1) ** 0.123, cirq.measure(q0, q1) + cirq.CX(q0, q1) ** 2, cirq.CZ(q0, q1) ** -1, cirq.measure(q0, q1) ) @@ -129,15 +129,22 @@ def three_qubit_gates(): ) -def noise_channels(): +def bit_flip(): q = cirq.LineQubit(0) return cirq.Circuit( cirq.X(q), cirq.bit_flip(0.1).on(q), + ) + + +def amplitude_damping(): + q = cirq.LineQubit(0) + + # NOTE: currently not supported -- marked as xfail below + return cirq.Circuit( cirq.amplitude_damp(0.1).on(q), cirq.generalized_amplitude_damp(p=0.1, gamma=0.05).on(q), - cirq.measure(q), ) @@ -160,7 +167,7 @@ def nested_circuit(): cirq.CircuitOperation( cirq.Circuit(cirq.H(q[1]), cirq.CX(q[1], q[2])).freeze(), use_repetition_ids=False, - ).controlled_by(q[0]), + ), cirq.measure(*q), ) @@ -177,7 +184,8 @@ def nested_circuit(): two_qubit_pow_gates, swap_circuit, three_qubit_gates, - noise_channels, + nested_circuit, + bit_flip, depolarizing_channels, ], ) @@ -211,12 +219,6 @@ def test_return_register(): assert kernel.return_type.body.is_subseteq(ilist.IListType) -@pytest.mark.xfail -def test_nested_circuit(): - # TODO: lowering for CircuitOperation - test_circuit(nested_circuit) - - def test_passing_in_register(): circuit = pow_gate_circuit() print(circuit) @@ -407,3 +409,8 @@ def multi_arg(n: int, p: float): circuit = emit_circuit(multi_arg, args=(3, 0.1)) print(circuit) + + +@pytest.mark.xfail +def test_amplitude_damping(): + test_circuit(amplitude_damping) diff --git a/test/cirq_utils/test_clifford_to_cirq.py b/test/cirq_utils/test_clifford_to_cirq.py new file mode 100644 index 00000000..46c3c259 --- /dev/null +++ b/test/cirq_utils/test_clifford_to_cirq.py @@ -0,0 +1,283 @@ +import math +import typing + +import cirq +import pytest +from kirin.emit import EmitError +from kirin.dialects import ilist + +from bloqade import squin +from bloqade.cirq_utils import emit, emit_circuit + + +def test_pauli(): + @squin.kernel + def main(): + q = squin.qubit.new(2) + q2 = squin.qubit.new(4) + x = squin.op.x() + y = squin.op.y() + z = squin.op.z() + squin.qubit.apply(x, q[0]) + squin.qubit.apply(y, q2[0]) + squin.qubit.apply(z, q2[3]) + + circuit = emit_circuit(main) + + print(circuit) + + qbits = circuit.all_qubits() + assert len(qbits) == 3 + assert isinstance(qbit := list(qbits)[-1], cirq.LineQubit) + assert qbit.x == 5 + + +@pytest.mark.parametrize("op_name", ["h", "s", "t", "x", "y", "z"]) +def test_basic_op(op_name: str): + @squin.kernel + def main(): + q = squin.qubit.new(1) + getattr(squin, op_name)(q) + + emit_circuit(main) + + +def test_control(): + @squin.kernel + def main(): + q = squin.qubit.new(2) + squin.h(q[0]) + squin.cx(q[0], q[1]) + + circuit = emit_circuit(main) + + print(circuit) + + assert len(circuit) == 2 + assert circuit[1].operations[0].gate == cirq.CNOT + + +def test_custom_qubits(): + @squin.kernel + def main(): + q = squin.qubit.new(2) + squin.h(q[0]) + squin.cx(q[0], q[1]) + + qubits = [cirq.GridQubit(0, 1), cirq.GridQubit(2, 2)] + circuit = emit_circuit(main, qubits=qubits) + + print(circuit) + + circuit_qubits = circuit.all_qubits() + assert len(circuit_qubits) == 2 + assert frozenset(qubits) == circuit_qubits + + +def test_composed_kernels(): + @squin.kernel + def sub_kernel(q_: ilist.IList[squin.qubit.Qubit, typing.Any]): + squin.h(q_[0]) + + @squin.kernel + def main(): + q = squin.qubit.new(2) + sub_kernel(q) + + circuit = emit_circuit(main) + + print(circuit) + + assert len(circuit) == 1 + assert len(circuit[0].operations) == 1 + assert isinstance(circuit[0].operations[0], cirq.GateOperation) + + +def test_nested_kernels(): + @squin.kernel + def sub_kernel2(q2_: ilist.IList[squin.qubit.Qubit, typing.Any]): + squin.cx(q2_[0], q2_[1]) + + @squin.kernel + def sub_kernel(q_: ilist.IList[squin.qubit.Qubit, typing.Any]): + squin.h(q_[0]) + sub_kernel2(q_) + + @squin.kernel + def main(): + q = squin.qubit.new(2) + sub_kernel(q) + + circuit = emit_circuit(main) + + print(circuit) + + +def test_return_value(): + @squin.kernel + def sub_kernel(): + q = squin.qubit.new(2) + squin.h(q[0]) + squin.cx(q[0], q[1]) + return q + + @squin.kernel + def main(): + q = sub_kernel() + squin.h(q[0]) + + circuit = emit_circuit(main) + + print(circuit) + + with pytest.raises(EmitError): + emit_circuit(sub_kernel) + + @squin.kernel + def main2(): + q = sub_kernel() + squin.h(q[0]) + return q + + circuit2 = emit_circuit(main2, ignore_returns=True) + print(circuit2) + + assert circuit2 == circuit + + +def test_return_qubits(): + @squin.kernel + def sub_kernel(q: ilist.IList[squin.qubit.Qubit, typing.Any]): + squin.h(q[0]) + q2 = squin.qubit.new(3) + squin.cx(q[0], q2[2]) + return q2 + + @squin.kernel + def main(): + q = squin.qubit.new(2) + q2_ = sub_kernel(q) + squin.x(q2_[0]) + + circuit = emit_circuit(main) + + print(circuit) + + +def test_measurement(): + @squin.kernel + def main(): + q = squin.qubit.new(2) + squin.broadcast.y(q) + squin.qubit.measure(q) + + circuit = emit_circuit(main) + + print(circuit) + + +def test_adjoint(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + squin.s(q[0]) + squin.s_adj(q[0]) + + circuit = emit_circuit(main) + print(circuit) + + +def test_u3(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + squin.u3(0.323, 1.123, math.pi / 7, q[0]) + + circuit = emit_circuit(main) + print(circuit) + + +def test_shift(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + squin.shift(math.pi / 7, q[0]) + + circuit = emit_circuit(main) + print(circuit) + + +def test_invoke_cache(): + @squin.kernel + def sub_kernel(q_: squin.qubit.Qubit): + squin.h(q_) + + @squin.kernel + def main(): + q = squin.qubit.new(2) + q0 = q[0] + sub_kernel(q0) + sub_kernel(q[1]) + sub_kernel(q0) + + target = emit.base.EmitCirq(main.dialects) + + circuit = target.run(main, ()) + + print(circuit) + + # caches as well as squin.h and squin.broadcast.h with the different qubits + assert len(target._cached_invokes) == 6 + + +def test_rot(): + @squin.kernel + def main(): + q = squin.qubit.new(1) + squin.rx(math.pi / 2, q[0]) + + circuit = emit_circuit(main) + + print(circuit) + + assert circuit[0].operations[0].gate == cirq.Rx(rads=math.pi / 2) + + +def test_additional_stmts(): + @squin.kernel + def main(): + q = squin.qubit.new(3) + squin.rot(math.pi / 4, math.pi / 2, -math.pi / 4, q[0]) + squin.sqrt_x(q[1]) + squin.sqrt_y(q[2]) + + main.print() + + circuit = emit_circuit(main) + + print(circuit) + + q = cirq.LineQubit.range(3) + expected_circuit = cirq.Circuit( + cirq.Rz(rads=math.pi / 4).on(q[0]), + cirq.Ry(rads=math.pi / 2).on(q[0]), + cirq.Rz(rads=-math.pi / 4).on(q[0]), + cirq.X(q[1]) ** 0.5, + cirq.Y(q[2]) ** 0.5, + ) + + assert circuit == expected_circuit + + +def test_return_measurement(): + + @squin.kernel + def coinflip(): + qubit = squin.qubit.new(1)[0] + squin.h(qubit) + return squin.qubit.measure(qubit) + + coinflip.print() + + circuit = emit_circuit(coinflip, ignore_returns=True) + print(circuit) diff --git a/test/cirq_utils/squin_noise_to_cirq.py b/test/cirq_utils/test_squin_noise_to_cirq.py similarity index 100% rename from test/cirq_utils/squin_noise_to_cirq.py rename to test/cirq_utils/test_squin_noise_to_cirq.py diff --git a/test/cirq_utils/test_squin_to_cirq.py b/test/cirq_utils/test_squin_to_cirq.py deleted file mode 100644 index f3cf2fd3..00000000 --- a/test/cirq_utils/test_squin_to_cirq.py +++ /dev/null @@ -1,470 +0,0 @@ -import math -import typing - -import cirq -import pytest -from kirin.emit import EmitError -from kirin.passes import inline -from kirin.dialects import ilist - -from bloqade import squin -from bloqade.cirq_utils import emit, emit_circuit - - -def test_pauli(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - q2 = squin.qubit.new(4) - x = squin.op.x() - y = squin.op.y() - z = squin.op.z() - squin.qubit.apply(x, q[0]) - squin.qubit.apply(y, q2[0]) - squin.qubit.apply(z, q2[3]) - - circuit = emit_circuit(main) - - print(circuit) - - qbits = circuit.all_qubits() - assert len(qbits) == 3 - assert isinstance(qbit := list(qbits)[-1], cirq.LineQubit) - assert qbit.x == 5 - - -@pytest.mark.parametrize("op_name", ["h", "s", "t", "x", "y", "z"]) -def test_basic_op(op_name: str): - @squin.kernel - def main(): - q = squin.qubit.new(1) - op_ = getattr(squin.op, op_name)() - squin.qubit.apply(op_, q) - - emit_circuit(main) - - -def test_control(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - squin.qubit.apply(cx, q) - - circuit = emit_circuit(main) - - print(circuit) - - assert len(circuit) == 2 - assert circuit[1].operations[0].gate == cirq.CNOT - - -def test_custom_qubits(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - squin.qubit.apply(cx, q) - - qubits = [cirq.GridQubit(0, 1), cirq.GridQubit(2, 2)] - circuit = emit_circuit(main, qubits=qubits) - - print(circuit) - - circuit_qubits = circuit.all_qubits() - assert len(circuit_qubits) == 2 - assert frozenset(qubits) == circuit_qubits - - -def test_composed_kernels(): - @squin.kernel - def sub_kernel(q_: ilist.IList[squin.qubit.Qubit, typing.Any]): - h = squin.op.h() - squin.qubit.apply(h, q_[0]) - - @squin.kernel - def main(): - q = squin.qubit.new(2) - sub_kernel(q) - - circuit = emit_circuit(main) - - print(circuit) - - assert len(circuit) == 1 - assert len(circuit[0].operations) == 1 - assert isinstance(circuit[0].operations[0], cirq.CircuitOperation) - - -def test_nested_kernels(): - @squin.kernel - def sub_kernel2(q2_: ilist.IList[squin.qubit.Qubit, typing.Any]): - cx = squin.op.control(squin.op.x(), n_controls=1) - squin.qubit.apply(cx, q2_[0], q2_[1]) - - @squin.kernel - def sub_kernel(q_: ilist.IList[squin.qubit.Qubit, typing.Any]): - h = squin.op.h() - squin.qubit.apply(h, q_[0]) - id = squin.op.identity(sites=1) - squin.qubit.apply(id, q_[1]) - sub_kernel2(q_) - - @squin.kernel - def main(): - q = squin.qubit.new(2) - sub_kernel(q) - - circuit = emit_circuit(main) - - print(circuit) - - -def test_return_value(): - @squin.kernel - def sub_kernel(q: ilist.IList[squin.qubit.Qubit, typing.Any]): - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - squin.qubit.apply(cx, q[0], q[1]) - return h - - @squin.kernel - def main(): - q = squin.qubit.new(2) - h_ = sub_kernel(q) - squin.qubit.apply(h_, q[1]) - - circuit = emit_circuit(main) - - print(circuit) - - with pytest.raises(EmitError): - emit_circuit(sub_kernel) - - @squin.kernel - def main2(): - q = squin.qubit.new(2) - h_ = sub_kernel(q) - squin.qubit.apply(h_, q[1]) - return h_ - - circuit2 = emit_circuit(main2, ignore_returns=True) - print(circuit2) - - assert circuit2 == circuit - - -def test_return_qubits(): - @squin.kernel - def sub_kernel(q: ilist.IList[squin.qubit.Qubit, typing.Any]): - h = squin.op.h() - squin.qubit.apply(h, q[0]) - cx = squin.op.cx() - q2 = squin.qubit.new(3) - squin.qubit.apply(cx, [q[0], q2[2]]) - return q2 - - @squin.kernel - def main(): - q = squin.qubit.new(2) - q2_ = sub_kernel(q) - squin.qubit.apply(squin.op.x(), q2_[0]) - - circuit = emit_circuit(main) - - print(circuit) - - -def test_measurement(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - y = squin.op.y() - squin.qubit.broadcast(y, q) - squin.qubit.measure(q) - - circuit = emit_circuit(main) - - print(circuit) - - -def test_kron(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - x = squin.op.x() - xx = squin.op.kron(x, x) - squin.qubit.apply(xx, q) - - circuit = emit_circuit(main) - - print(circuit) - - -def test_mult(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - x = squin.op.x() - y = squin.op.y() - m = squin.op.mult(x, y) - squin.qubit.apply(m, q[0]) - - circuit = emit_circuit(main) - - print(circuit) - - -def test_projector(): - @squin.kernel - def main(): - q = squin.qubit.new(2) - h = squin.op.h() - squin.qubit.broadcast(h, q) - p0 = squin.op.p0() - p1 = squin.op.p1() - squin.qubit.apply(p0, q[0]) - squin.qubit.apply(p1, q[1]) - squin.qubit.measure(q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_sp_sn(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - sp = squin.op.spin_p() - sn = squin.op.spin_n() - squin.qubit.apply(sp, q) - squin.qubit.apply(sn, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_adjoint(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - s = squin.op.s() - s_dagger = squin.op.adjoint(s) - squin.qubit.apply(s, q) - squin.qubit.apply(s_dagger, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_u3(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - u3 = squin.op.u(0.323, 1.123, math.pi / 7) - squin.qubit.apply(u3, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_scale(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - x = squin.op.x() - s = 2 * x - squin.qubit.apply(s, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_phase(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - p = squin.op.phase(math.pi / 3) - squin.qubit.apply(p, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_shift(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - p = squin.op.shift(math.pi / 7) - squin.qubit.apply(p, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_reset(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - r = squin.op.reset() - squin.qubit.apply(r, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_pauli_string(): - @squin.kernel - def main(): - p = squin.op.pauli_string(string="XYZ") - q = squin.qubit.new(3) - squin.qubit.apply(p, q) - - circuit = emit_circuit(main) - print(circuit) - - -def test_invoke_cache(): - @squin.kernel - def sub_kernel(q_: squin.qubit.Qubit): - squin.qubit.apply(squin.op.h(), q_) - - @squin.kernel - def main(): - q = squin.qubit.new(2) - q0 = q[0] - sub_kernel(q0) - sub_kernel(q[1]) - sub_kernel(q0) - - target = emit.base.EmitCirq(main.dialects) - - circuit = target.run(main, ()) - - print(circuit) - - assert len(target._cached_circuit_operations) == 2 - - -def test_rot(): - @squin.kernel - def main(): - axis = squin.op.x() - q = squin.qubit.new(1) - r = squin.op.rot(axis=axis, angle=math.pi / 2) - squin.qubit.apply(r, q[0]) - - circuit = emit_circuit(main) - - print(circuit) - - assert circuit[0].operations[0].gate == cirq.Rx(rads=math.pi / 2) - - @squin.kernel - def main2(): - x = squin.op.x() - y = squin.op.y() - q = squin.qubit.new(1) - r = squin.op.rot(axis=x * y, angle=0.123) - squin.qubit.apply(r, q[0]) - - with pytest.raises(EmitError): - emit_circuit(main2) - - @squin.kernel - def main3(): - op = squin.op.h() - q = squin.qubit.new(1) - r = squin.op.rot(axis=op, angle=0.123) - squin.qubit.apply(r, q[0]) - - with pytest.raises(EmitError): - emit_circuit(main3) - - -def test_additional_stmts(): - @squin.kernel - def main(): - x = squin.op.x() - r = squin.op.rot(x, 0.123) - q = squin.qubit.new(3) - squin.qubit.apply(r, q[0]) - sqrt_x = squin.op.sqrt_x() - sqrt_y = squin.op.sqrt_y() - squin.qubit.apply(sqrt_x, q[1]) - squin.qubit.apply(sqrt_y, q[2]) - - main.print() - - circuit = emit_circuit(main) - - print(circuit) - - q = cirq.LineQubit.range(3) - expected_circuit = cirq.Circuit( - cirq.Rx(rads=0.123).on(q[0]), - cirq.X(q[1]) ** 0.5, - cirq.Y(q[2]) ** 0.5, - ) - - assert circuit == expected_circuit - - -def test_return_measurement(): - - @squin.kernel - def coinflip(): - qubit = squin.qubit.new(1)[0] - squin.gate.h(qubit) - return squin.qubit.measure(qubit) - - coinflip.print() - - circuit = emit_circuit(coinflip, ignore_returns=True) - print(circuit) - - -def test_reset_to_one(): - @squin.kernel - def main(): - q = squin.qubit.new(1) - squin.gate.h(q[0]) - squin.gate.reset_to_one(q[0]) - - inline.InlinePass(main.dialects)(main) - - main.print() - - circuit = emit_circuit(main) - - q = cirq.LineQubit(0) - expected_circuit = cirq.Circuit( - cirq.H(q), - cirq.reset(q), - cirq.X(q), - ) - - print(circuit) - - assert circuit == expected_circuit - - -def test_overlapping_operations(): - @squin.kernel - def main(): - q = squin.qubit.new(5) - - x = squin.op.x() - y = squin.op.y() - op = x * y - - squin.qubit.broadcast(op, q) - - circuit = emit_circuit(main) - - print(circuit)