Skip to content

Commit cc718a1

Browse files
david-plweinbe58
andcommitted
Emit a cirq.Circuit from a squin kernel (#311)
There's still quite a few things to do here: * Some of the operator statements (e.g. `Kron`) don't have a 1:1 mapping to a cirq gate. They probably need a "runtime", since they are easy enough to apply to qubits. Alternatively, we could use custom gates in cirq, but I'm not much of a fan. * Methods for the noise statements are still missing. * There is some very hacky stuff I'm doing right now in order to obtain the `BlockArguments` and `ReturnValues` in the frame for nested kernels. There's probably a better way to do this, but I don't know how (yet). --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent 2e551e9 commit cc718a1

File tree

7 files changed

+978
-2
lines changed

7 files changed

+978
-2
lines changed

src/bloqade/squin/cirq/__init__.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1-
from typing import Any
1+
from typing import Any, Sequence
22

33
import cirq
44
from kirin import ir, types
5+
from kirin.emit import EmitError
56
from kirin.dialects import func
67

78
from . import lowering as lowering
89
from .. import kernel
10+
11+
# NOTE: just to register methods
12+
from .emit import op as op, qubit as qubit
913
from .lowering import Squin
14+
from .emit.emit_circuit import EmitCirq
1015

1116

1217
def load_circuit(
@@ -87,3 +92,109 @@ def load_circuit(
8792
dialects=dialects,
8893
code=code,
8994
)
95+
96+
97+
def emit_circuit(
98+
mt: ir.Method,
99+
qubits: Sequence[cirq.Qid] | None = None,
100+
) -> cirq.Circuit:
101+
"""Converts a squin.kernel method to a cirq.Circuit object.
102+
103+
Args:
104+
mt (ir.Method): The kernel method from which to construct the circuit.
105+
106+
Keyword Args:
107+
qubits (Sequence[cirq.Qid] | None):
108+
A list of qubits to use as the qubits in the circuit. Defaults to None.
109+
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qubit.new`
110+
statement in the order they appear inside the kernel.
111+
**Note**: If a list of qubits is provided, make sure that there is a sufficient
112+
number of qubits for the resulting circuit.
113+
114+
## Examples:
115+
116+
Here's a very basic example:
117+
118+
```python
119+
from bloqade import squin
120+
121+
@squin.kernel
122+
def main():
123+
q = squin.qubit.new(2)
124+
h = squin.op.h()
125+
squin.qubit.apply(h, q[0])
126+
cx = squin.op.cx()
127+
squin.qubit.apply(cx, q)
128+
129+
circuit = squin.cirq.emit_circuit(main)
130+
131+
print(circuit)
132+
```
133+
134+
You can also compose multiple kernels. Those are emitted as subcircuits within the "main" circuit.
135+
Subkernels can accept arguments and return a value.
136+
137+
```python
138+
from bloqade import squin
139+
from kirin.dialects import ilist
140+
from typing import Literal
141+
import cirq
142+
143+
@squin.kernel
144+
def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]):
145+
h = squin.op.h()
146+
squin.qubit.apply(h, q[0])
147+
cx = squin.op.cx()
148+
squin.qubit.apply(cx, q)
149+
return cx
150+
151+
@squin.kernel
152+
def main():
153+
q = squin.qubit.new(2)
154+
cx = entangle(q)
155+
q2 = squin.qubit.new(3)
156+
squin.qubit.apply(cx, [q[1], q2[2]])
157+
158+
159+
# custom list of qubits on grid
160+
qubits = [cirq.GridQubit(i, i+1) for i in range(5)]
161+
162+
circuit = squin.cirq.emit_circuit(main, qubits=qubits)
163+
print(circuit)
164+
165+
```
166+
167+
We also passed in a custom list of qubits above. This allows you to provide a custom geometry
168+
and manipulate the qubits in other circuits directly written in cirq as well.
169+
"""
170+
171+
if isinstance(mt.code, func.Function) and not mt.code.signature.output.is_subseteq(
172+
types.NoneType
173+
):
174+
raise EmitError(
175+
"The method you are trying to convert to a circuit has a return value, but returning from a circuit is not supported."
176+
)
177+
178+
emitter = EmitCirq(qubits=qubits)
179+
return emitter.run(mt, args=())
180+
181+
182+
def dump_circuit(mt: ir.Method, qubits: Sequence[cirq.Qid] | None = None, **kwargs):
183+
"""Converts a squin.kernel method to a cirq.Circuit object and dumps it as JSON.
184+
185+
This just runs `emit_circuit` and calls the `cirq.to_json` function to emit a JSON.
186+
187+
Args:
188+
mt (ir.Method): The kernel method from which to construct the circuit.
189+
190+
Keyword Args:
191+
qubits (Sequence[cirq.Qid] | None):
192+
A list of qubits to use as the qubits in the circuit. Defaults to None.
193+
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qubit.new`
194+
statement in the order they appear inside the kernel.
195+
**Note**: If a list of qubits is provided, make sure that there is a sufficient
196+
number of qubits for the resulting circuit.
197+
198+
"""
199+
circuit = emit_circuit(mt, qubits=qubits)
200+
return cirq.to_json(circuit, **kwargs)
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from typing import Sequence
2+
from dataclasses import field, dataclass
3+
4+
import cirq
5+
from kirin import ir
6+
from kirin.emit import EmitABC, EmitError, EmitFrame
7+
from kirin.interp import MethodTable, impl
8+
from kirin.dialects import func
9+
from typing_extensions import Self
10+
11+
from ... import kernel
12+
13+
14+
@dataclass
15+
class EmitCirqFrame(EmitFrame):
16+
qubit_index: int = 0
17+
qubits: Sequence[cirq.Qid] | None = None
18+
circuit: cirq.Circuit = field(default_factory=cirq.Circuit)
19+
20+
21+
def _default_kernel():
22+
return kernel
23+
24+
25+
@dataclass
26+
class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
27+
keys = ["emit.cirq", "main"]
28+
dialects: ir.DialectGroup = field(default_factory=_default_kernel)
29+
void = cirq.Circuit()
30+
qubits: Sequence[cirq.Qid] | None = None
31+
_cached_circuit_operations: dict[int, cirq.CircuitOperation] = field(
32+
init=False, default_factory=dict
33+
)
34+
35+
def initialize(self) -> Self:
36+
return super().initialize()
37+
38+
def initialize_frame(
39+
self, code: ir.Statement, *, has_parent_access: bool = False
40+
) -> EmitCirqFrame:
41+
return EmitCirqFrame(
42+
code, has_parent_access=has_parent_access, qubits=self.qubits
43+
)
44+
45+
def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
46+
return self.run_callable(method.code, args)
47+
48+
def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
49+
for stmt in block.stmts:
50+
result = self.eval_stmt(frame, stmt)
51+
if isinstance(result, tuple):
52+
frame.set_values(stmt.results, result)
53+
54+
return frame.circuit
55+
56+
57+
@func.dialect.register(key="emit.cirq")
58+
class FuncEmit(MethodTable):
59+
60+
@impl(func.Function)
61+
def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
62+
emit.run_ssacfg_region(frame, stmt.body, ())
63+
return (frame.circuit,)
64+
65+
@impl(func.Invoke)
66+
def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
67+
stmt_hash = hash((stmt.callee, stmt.inputs))
68+
if (
69+
cached_circuit_op := emit._cached_circuit_operations.get(stmt_hash)
70+
) is not None:
71+
# NOTE: cache hit
72+
frame.circuit.append(cached_circuit_op)
73+
return ()
74+
75+
ret = stmt.result
76+
77+
with emit.new_frame(stmt.callee.code, has_parent_access=True) as sub_frame:
78+
sub_frame.qubit_index = frame.qubit_index
79+
sub_frame.qubits = frame.qubits
80+
81+
region = stmt.callee.callable_region
82+
if len(region.blocks) > 1:
83+
raise EmitError(
84+
"Subroutine with more than a single block encountered. This is not supported!"
85+
)
86+
87+
# NOTE: get the arguments, "self" is just an empty circuit
88+
method_self = emit.void
89+
args = [frame.get(arg_) for arg_ in stmt.inputs]
90+
emit.run_ssacfg_region(
91+
sub_frame, stmt.callee.callable_region, args=(method_self, *args)
92+
)
93+
sub_circuit = sub_frame.circuit
94+
95+
# NOTE: check to see if the call terminates with a return value and fetch the value;
96+
# we don't support multiple return statements via control flow so we just pick the first one
97+
block = region.blocks[0]
98+
return_stmt = next(
99+
(stmt for stmt in block.stmts if isinstance(stmt, func.Return)), None
100+
)
101+
if return_stmt is not None:
102+
frame.entries[ret] = sub_frame.get(return_stmt.value)
103+
104+
circuit_op = cirq.CircuitOperation(
105+
sub_circuit.freeze(), use_repetition_ids=False
106+
)
107+
emit._cached_circuit_operations[stmt_hash] = circuit_op
108+
frame.circuit.append(circuit_op)
109+
return ()

src/bloqade/squin/cirq/emit/op.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import math
2+
3+
import cirq
4+
import numpy as np
5+
from kirin.interp import MethodTable, impl
6+
7+
from ... import op
8+
from .runtime import (
9+
SnRuntime,
10+
SpRuntime,
11+
U3Runtime,
12+
KronRuntime,
13+
MultRuntime,
14+
ScaleRuntime,
15+
AdjointRuntime,
16+
ControlRuntime,
17+
UnitaryRuntime,
18+
HermitianRuntime,
19+
ProjectorRuntime,
20+
OperatorRuntimeABC,
21+
PauliStringRuntime,
22+
)
23+
from .emit_circuit import EmitCirq, EmitCirqFrame
24+
25+
26+
@op.dialect.register(key="emit.cirq")
27+
class EmitCirqOpMethods(MethodTable):
28+
29+
@impl(op.stmts.X)
30+
@impl(op.stmts.Y)
31+
@impl(op.stmts.Z)
32+
@impl(op.stmts.H)
33+
def hermitian(
34+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ConstantUnitary
35+
):
36+
cirq_op = getattr(cirq, stmt.name.upper())
37+
return (HermitianRuntime(cirq_op),)
38+
39+
@impl(op.stmts.S)
40+
@impl(op.stmts.T)
41+
def unitary(
42+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ConstantUnitary
43+
):
44+
cirq_op = getattr(cirq, stmt.name.upper())
45+
return (UnitaryRuntime(cirq_op),)
46+
47+
@impl(op.stmts.P0)
48+
@impl(op.stmts.P1)
49+
def projector(
50+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.P0 | op.stmts.P1
51+
):
52+
return (ProjectorRuntime(isinstance(stmt, op.stmts.P1)),)
53+
54+
@impl(op.stmts.Sn)
55+
def sn(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sn):
56+
return (SnRuntime(),)
57+
58+
@impl(op.stmts.Sp)
59+
def sp(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sp):
60+
return (SpRuntime(),)
61+
62+
@impl(op.stmts.Identity)
63+
def identity(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Identity):
64+
op = HermitianRuntime(cirq.IdentityGate(num_qubits=stmt.sites))
65+
return (op,)
66+
67+
@impl(op.stmts.Control)
68+
def control(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Control):
69+
op: OperatorRuntimeABC = frame.get(stmt.op)
70+
return (ControlRuntime(op, stmt.n_controls),)
71+
72+
@impl(op.stmts.Kron)
73+
def kron(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Kron):
74+
lhs = frame.get(stmt.lhs)
75+
rhs = frame.get(stmt.rhs)
76+
op = KronRuntime(lhs, rhs)
77+
return (op,)
78+
79+
@impl(op.stmts.Mult)
80+
def mult(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Mult):
81+
lhs = frame.get(stmt.lhs)
82+
rhs = frame.get(stmt.rhs)
83+
op = MultRuntime(lhs, rhs)
84+
return (op,)
85+
86+
@impl(op.stmts.Adjoint)
87+
def adjoint(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Adjoint):
88+
op_ = frame.get(stmt.op)
89+
return (AdjointRuntime(op_),)
90+
91+
@impl(op.stmts.Scale)
92+
def scale(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Scale):
93+
op_ = frame.get(stmt.op)
94+
factor = frame.get(stmt.factor)
95+
return (ScaleRuntime(operator=op_, factor=factor),)
96+
97+
@impl(op.stmts.U3)
98+
def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.U3):
99+
theta = frame.get(stmt.theta)
100+
phi = frame.get(stmt.phi)
101+
lam = frame.get(stmt.lam)
102+
return (U3Runtime(theta=theta, phi=phi, lam=lam),)
103+
104+
@impl(op.stmts.PhaseOp)
105+
def phaseop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PhaseOp):
106+
theta = frame.get(stmt.theta)
107+
op_ = HermitianRuntime(cirq.IdentityGate(num_qubits=1))
108+
return (ScaleRuntime(operator=op_, factor=np.exp(1j * theta)),)
109+
110+
@impl(op.stmts.ShiftOp)
111+
def shiftop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ShiftOp):
112+
theta = frame.get(stmt.theta)
113+
114+
# NOTE: ShiftOp(theta) == U3(pi, theta, 0)
115+
return (U3Runtime(math.pi, theta, 0),)
116+
117+
@impl(op.stmts.Reset)
118+
def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Reset):
119+
return (HermitianRuntime(cirq.ResetChannel()),)
120+
121+
@impl(op.stmts.PauliString)
122+
def pauli_string(
123+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliString
124+
):
125+
return (PauliStringRuntime(stmt.string),)

0 commit comments

Comments
 (0)