Skip to content
Merged
2 changes: 1 addition & 1 deletion src/bloqade/cirq_utils/emit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# NOTE: just to register methods
from . import op as op, noise as noise, qubit as qubit
from . import op as op, noise as noise, qubit as qubit, clifford as clifford
from .base import emit_circuit as emit_circuit
59 changes: 31 additions & 28 deletions src/bloqade/cirq_utils/emit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,15 @@ def emit_circuit(

```python
from bloqade import squin
from bloqade.cirq_utils import emit_circuit

@squin.kernel
def main():
q = squin.qubit.new(2)
h = squin.op.h()
squin.qubit.apply(h, q[0])
cx = squin.op.cx()
squin.qubit.apply(cx, q)
squin.h(q[0])
squin.cx(q[0], q[1])

circuit = squin.cirq.emit_circuit(main)
circuit = emit_circuit(main)

print(circuit)
```
Expand All @@ -62,30 +61,25 @@ def main():

```python
from bloqade import squin
from bloqade.cirq_utils import emit_circuit
from kirin.dialects import ilist
from typing import Literal
import cirq

@squin.kernel
def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]):
h = squin.op.h()
squin.qubit.apply(h, q[0])
cx = squin.op.cx()
squin.qubit.apply(cx, q)
return cx
squin.h(q[0])
squin.cx(q[0], q[1])

@squin.kernel
def main():
q = squin.qubit.new(2)
cx = entangle(q)
q2 = squin.qubit.new(3)
squin.qubit.apply(cx, [q[1], q2[2]])

entangle(q)

# custom list of qubits on grid
qubits = [cirq.GridQubit(i, i+1) for i in range(5)]

circuit = squin.cirq.emit_circuit(main, circuit_qubits=qubits)
circuit = emit_circuit(main, circuit_qubits=qubits)
print(circuit)

```
Expand Down Expand Up @@ -115,7 +109,7 @@ def main():
f"The method from which you're trying to emit a circuit takes {len(mt.args)} as input, but you passed in {len(args)} via the `args` keyword!"
)

emitter = EmitCirq(qubits=qubits)
emitter = EmitCirq(qubits=circuit_qubits)

return emitter.run(mt, args=args)

Expand All @@ -137,7 +131,7 @@ class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
dialects: ir.DialectGroup = field(default_factory=_default_kernel)
void = cirq.Circuit()
qubits: Sequence[cirq.Qid] | None = None
_cached_circuit_operations: dict[int, cirq.CircuitOperation] = field(
_cached_invokes: dict[int, cirq.FrozenCircuit] = field(
init=False, default_factory=dict
)

Expand Down Expand Up @@ -184,7 +178,7 @@ def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:


@func.dialect.register(key="emit.cirq")
class FuncEmit(MethodTable):
class __FuncEmit(MethodTable):

@impl(func.Function)
def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
Expand All @@ -193,12 +187,22 @@ def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):

@impl(func.Invoke)
def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
stmt_hash = hash((stmt.callee, stmt.inputs))
if (
cached_circuit_op := emit._cached_circuit_operations.get(stmt_hash)
) is not None:
try:
stmt_hash = hash(
(stmt.callee, tuple(frame.get(input) for input in stmt.inputs))
)
except (TypeError, interp.InterpreterError):
# NOTE: avoid unhashable types and missing keys, just don't cache them
stmt_hash = None

if stmt_hash is not None:
cached_circuit = emit._cached_invokes.get(stmt_hash)
else:
cached_circuit = None

if cached_circuit is not None:
# NOTE: cache hit
frame.circuit.append(cached_circuit_op)
frame.circuit.append(cached_circuit.all_operations())
return ()

ret = stmt.result
Expand Down Expand Up @@ -230,9 +234,8 @@ def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
if return_stmt is not None:
frame.entries[ret] = sub_frame.get(return_stmt.value)

circuit_op = cirq.CircuitOperation(
sub_circuit.freeze(), use_repetition_ids=False
)
emit._cached_circuit_operations[stmt_hash] = circuit_op
frame.circuit.append(circuit_op)
if stmt_hash is not None:
emit._cached_invokes[stmt_hash] = sub_circuit.freeze()

frame.circuit.append(sub_circuit.all_operations())
return ()
90 changes: 90 additions & 0 deletions src/bloqade/cirq_utils/emit/clifford.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import math

import cirq
from kirin.interp import MethodTable, impl

from bloqade.squin import clifford

from .base import EmitCirq, EmitCirqFrame


@clifford.dialect.register(key="emit.cirq")
class __EmitCirqCliffordMethods(MethodTable):

@impl(clifford.stmts.X)
@impl(clifford.stmts.Y)
@impl(clifford.stmts.Z)
@impl(clifford.stmts.H)
def hermitian(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.SingleQubitGate
):
qubits = frame.get(stmt.qubits)
cirq_op = getattr(cirq, stmt.name.upper())
frame.circuit.append(cirq_op.on_each(qubits))
return ()

@impl(clifford.stmts.S)
@impl(clifford.stmts.T)
def unitary(
self,
emit: EmitCirq,
frame: EmitCirqFrame,
stmt: clifford.stmts.SingleQubitNonHermitianGate,
):
qubits = frame.get(stmt.qubits)
cirq_op = getattr(cirq, stmt.name.upper())
if stmt.adjoint:
cirq_op = cirq_op ** (-1)

frame.circuit.append(cirq_op.on_each(qubits))
return ()

@impl(clifford.stmts.SqrtX)
@impl(clifford.stmts.SqrtY)
def sqrt(
self,
emit: EmitCirq,
frame: EmitCirqFrame,
stmt: clifford.stmts.SqrtX | clifford.stmts.SqrtY,
):
qubits = frame.get(stmt.qubits)

exponent = 0.5
if stmt.adjoint:
exponent *= -1

if isinstance(stmt, clifford.stmts.SqrtX):
cirq_op = cirq.XPowGate(exponent=exponent)
else:
cirq_op = cirq.YPowGate(exponent=exponent)

frame.circuit.append(cirq_op.on_each(qubits))
return ()

@impl(clifford.stmts.CX)
@impl(clifford.stmts.CY)
@impl(clifford.stmts.CZ)
def control(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.ControlledGate
):
controls = frame.get(stmt.controls)
targets = frame.get(stmt.targets)
cirq_op = getattr(cirq, stmt.name.upper())
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
frame.circuit.append(cirq_op.on_each(cirq_qubits))
return ()

@impl(clifford.stmts.Rx)
@impl(clifford.stmts.Ry)
@impl(clifford.stmts.Rz)
def rot(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.RotationGate
):
qubits = frame.get(stmt.qubits)

turns = frame.get(stmt.angle)
angle = turns * 2 * math.pi
cirq_op = getattr(cirq, stmt.name.title())(rads=angle)

frame.circuit.append(cirq_op.on_each(qubits))
return ()
92 changes: 50 additions & 42 deletions src/bloqade/cirq_utils/emit/noise.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,53 @@
import cirq
from kirin.emit import EmitError
from kirin.interp import MethodTable, impl

from bloqade.squin import noise

from .base import EmitCirq, EmitCirqFrame
from .runtime import (
KronRuntime,
BasicOpRuntime,
OperatorRuntimeABC,
PauliStringRuntime,
)


@noise.dialect.register(key="emit.cirq")
class EmitCirqNoiseMethods(MethodTable):
class __EmitCirqNoiseMethods(MethodTable):

two_qubit_paulis = (
"IX",
"IY",
"IZ",
"XI",
"XX",
"XY",
"XZ",
"YI",
"YX",
"YY",
"YZ",
"ZI",
"ZX",
"ZY",
"ZZ",
)

@impl(noise.stmts.Depolarize)
def depolarize(
self, interp: EmitCirq, frame: EmitCirqFrame, stmt: noise.stmts.Depolarize
):
p = frame.get(stmt.p)
gate = cirq.depolarize(p, n_qubits=1)
return (BasicOpRuntime(gate=gate),)
qubits = frame.get(stmt.qubits)
cirfq_op = cirq.depolarize(p, n_qubits=1).on_each(qubits)
frame.circuit.append(cirfq_op)
return ()

@impl(noise.stmts.Depolarize2)
def depolarize2(
self, interp: EmitCirq, frame: EmitCirqFrame, stmt: noise.stmts.Depolarize2
):
p = frame.get(stmt.p)
gate = cirq.depolarize(p, n_qubits=2)
return (BasicOpRuntime(gate=gate),)
controls = frame.get(stmt.controls)
targets = frame.get(stmt.targets)
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
cirq_op = cirq.depolarize(p, n_qubits=2).on_each(cirq_qubits)
frame.circuit.append(cirq_op)
return ()

@impl(noise.stmts.SingleQubitPauliChannel)
def single_qubit_pauli_channel(
Expand All @@ -39,9 +56,15 @@ def single_qubit_pauli_channel(
frame: EmitCirqFrame,
stmt: noise.stmts.SingleQubitPauliChannel,
):
ps = frame.get(stmt.params)
gate = cirq.asymmetric_depolarize(*ps)
return (BasicOpRuntime(gate=gate),)
px = frame.get(stmt.px)
py = frame.get(stmt.py)
pz = frame.get(stmt.pz)
qubits = frame.get(stmt.qubits)

cirq_op = cirq.asymmetric_depolarize(px, py, pz).on_each(qubits)
frame.circuit.append(cirq_op)

return ()

@impl(noise.stmts.TwoQubitPauliChannel)
def two_qubit_pauli_channel(
Expand All @@ -50,33 +73,18 @@ def two_qubit_pauli_channel(
frame: EmitCirqFrame,
stmt: noise.stmts.TwoQubitPauliChannel,
):
ps = frame.get(stmt.params)
paulis = ("I", "X", "Y", "Z")
pauli_combinations = [
pauli1 + pauli2
for pauli1 in paulis
for pauli2 in paulis
if not (pauli1 == pauli2 == "I")
]
error_probabilities = {key: p for (key, p) in zip(pauli_combinations, ps)}
gate = cirq.asymmetric_depolarize(error_probabilities=error_probabilities)
return (BasicOpRuntime(gate),)

@staticmethod
def _op_to_key(operator: OperatorRuntimeABC) -> str:
match operator:
case KronRuntime():
key_lhs = EmitCirqNoiseMethods._op_to_key(operator.lhs)
key_rhs = EmitCirqNoiseMethods._op_to_key(operator.rhs)
return key_lhs + key_rhs
ps = frame.get(stmt.probabilities)
error_probabilities = {
key: p for (key, p) in zip(self.two_qubit_paulis, ps) if p != 0
}

case BasicOpRuntime():
return str(operator.gate)
controls = frame.get(stmt.controls)
targets = frame.get(stmt.targets)
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]

case PauliStringRuntime():
return operator.string
cirq_op = cirq.asymmetric_depolarize(
error_probabilities=error_probabilities
).on_each(cirq_qubits)
frame.circuit.append(cirq_op)

case _:
raise EmitError(
f"Unexpected operator runtime in StochasticUnitaryChannel of type {type(operator).__name__} encountered!"
)
return ()
8 changes: 5 additions & 3 deletions src/bloqade/cirq_utils/emit/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
n_qubits = frame.get(stmt.n_qubits)

if frame.qubits is not None:
cirq_qubits = [frame.qubits[i + frame.qubit_index] for i in range(n_qubits)]
cirq_qubits = tuple(
frame.qubits[i + frame.qubit_index] for i in range(n_qubits)
)
else:
cirq_qubits = [
cirq_qubits = tuple(
cirq.LineQubit(i + frame.qubit_index) for i in range(n_qubits)
]
)

frame.qubit_index += n_qubits
return (cirq_qubits,)
Expand Down
Loading