Skip to content

Commit 0480e3e

Browse files
authored
Refactor cirq - squin emit and lowering (#520)
Closes #497 I had to cut down support for some cirq features in lowering, but I don't think they were used heavily. For example, we no longer have `CZPowGate` with arbitrary exponents or arbitrary controlled operations. Also, one more significant breaking change is that `func.Invoke` statements are no longer emitted as `CircuitOperation`, but are added to the circuit flat-out. I think this matches user expectations and otherwise would have been really odd given our new approach with stdlib functions. Finally, noise still needs to be done, but that is pending #495, which should be merged first, I think. **Edit**: Noise is also done now, but I also had to cut support for `AmplitudeDampingChannel`, since we can't really represent this with the current statements in the noise dialect (I think). However, I haven't seen this used anywhere so far, so I'd suggest we wait for someone to actually require that feature.
1 parent e973ebe commit 0480e3e

File tree

13 files changed

+736
-755
lines changed

13 files changed

+736
-755
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# NOTE: just to register methods
2-
from . import op as op, noise as noise, qubit as qubit
2+
from . import op as op, noise as noise, qubit as qubit, clifford as clifford
33
from .base import emit_circuit as emit_circuit

src/bloqade/cirq_utils/emit/base.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,15 @@ def emit_circuit(
4343
4444
```python
4545
from bloqade import squin
46+
from bloqade.cirq_utils import emit_circuit
4647
4748
@squin.kernel
4849
def main():
4950
q = squin.qubit.new(2)
50-
h = squin.op.h()
51-
squin.qubit.apply(h, q[0])
52-
cx = squin.op.cx()
53-
squin.qubit.apply(cx, q)
51+
squin.h(q[0])
52+
squin.cx(q[0], q[1])
5453
55-
circuit = squin.cirq.emit_circuit(main)
54+
circuit = emit_circuit(main)
5655
5756
print(circuit)
5857
```
@@ -62,30 +61,25 @@ def main():
6261
6362
```python
6463
from bloqade import squin
64+
from bloqade.cirq_utils import emit_circuit
6565
from kirin.dialects import ilist
6666
from typing import Literal
6767
import cirq
6868
6969
@squin.kernel
7070
def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]):
71-
h = squin.op.h()
72-
squin.qubit.apply(h, q[0])
73-
cx = squin.op.cx()
74-
squin.qubit.apply(cx, q)
75-
return cx
71+
squin.h(q[0])
72+
squin.cx(q[0], q[1])
7673
7774
@squin.kernel
7875
def main():
7976
q = squin.qubit.new(2)
80-
cx = entangle(q)
81-
q2 = squin.qubit.new(3)
82-
squin.qubit.apply(cx, [q[1], q2[2]])
83-
77+
entangle(q)
8478
8579
# custom list of qubits on grid
8680
qubits = [cirq.GridQubit(i, i+1) for i in range(5)]
8781
88-
circuit = squin.cirq.emit_circuit(main, circuit_qubits=qubits)
82+
circuit = emit_circuit(main, circuit_qubits=qubits)
8983
print(circuit)
9084
9185
```
@@ -115,7 +109,7 @@ def main():
115109
f"The method from which you're trying to emit a circuit takes {len(mt.args)} as input, but you passed in {len(args)} via the `args` keyword!"
116110
)
117111

118-
emitter = EmitCirq(qubits=qubits)
112+
emitter = EmitCirq(qubits=circuit_qubits)
119113

120114
return emitter.run(mt, args=args)
121115

@@ -137,7 +131,7 @@ class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
137131
dialects: ir.DialectGroup = field(default_factory=_default_kernel)
138132
void = cirq.Circuit()
139133
qubits: Sequence[cirq.Qid] | None = None
140-
_cached_circuit_operations: dict[int, cirq.CircuitOperation] = field(
134+
_cached_invokes: dict[int, cirq.FrozenCircuit] = field(
141135
init=False, default_factory=dict
142136
)
143137

@@ -184,7 +178,7 @@ def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
184178

185179

186180
@func.dialect.register(key="emit.cirq")
187-
class FuncEmit(MethodTable):
181+
class __FuncEmit(MethodTable):
188182

189183
@impl(func.Function)
190184
def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
@@ -193,12 +187,22 @@ def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
193187

194188
@impl(func.Invoke)
195189
def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
196-
stmt_hash = hash((stmt.callee, stmt.inputs))
197-
if (
198-
cached_circuit_op := emit._cached_circuit_operations.get(stmt_hash)
199-
) is not None:
190+
try:
191+
stmt_hash = hash(
192+
(stmt.callee, tuple(frame.get(input) for input in stmt.inputs))
193+
)
194+
except (TypeError, interp.InterpreterError):
195+
# NOTE: avoid unhashable types and missing keys, just don't cache them
196+
stmt_hash = None
197+
198+
if stmt_hash is not None:
199+
cached_circuit = emit._cached_invokes.get(stmt_hash)
200+
else:
201+
cached_circuit = None
202+
203+
if cached_circuit is not None:
200204
# NOTE: cache hit
201-
frame.circuit.append(cached_circuit_op)
205+
frame.circuit.append(cached_circuit.all_operations())
202206
return ()
203207

204208
ret = stmt.result
@@ -230,9 +234,8 @@ def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
230234
if return_stmt is not None:
231235
frame.entries[ret] = sub_frame.get(return_stmt.value)
232236

233-
circuit_op = cirq.CircuitOperation(
234-
sub_circuit.freeze(), use_repetition_ids=False
235-
)
236-
emit._cached_circuit_operations[stmt_hash] = circuit_op
237-
frame.circuit.append(circuit_op)
237+
if stmt_hash is not None:
238+
emit._cached_invokes[stmt_hash] = sub_circuit.freeze()
239+
240+
frame.circuit.append(sub_circuit.all_operations())
238241
return ()
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import math
2+
3+
import cirq
4+
from kirin.interp import MethodTable, impl
5+
6+
from bloqade.squin import clifford
7+
8+
from .base import EmitCirq, EmitCirqFrame
9+
10+
11+
@clifford.dialect.register(key="emit.cirq")
12+
class __EmitCirqCliffordMethods(MethodTable):
13+
14+
@impl(clifford.stmts.X)
15+
@impl(clifford.stmts.Y)
16+
@impl(clifford.stmts.Z)
17+
@impl(clifford.stmts.H)
18+
def hermitian(
19+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.SingleQubitGate
20+
):
21+
qubits = frame.get(stmt.qubits)
22+
cirq_op = getattr(cirq, stmt.name.upper())
23+
frame.circuit.append(cirq_op.on_each(qubits))
24+
return ()
25+
26+
@impl(clifford.stmts.S)
27+
@impl(clifford.stmts.T)
28+
def unitary(
29+
self,
30+
emit: EmitCirq,
31+
frame: EmitCirqFrame,
32+
stmt: clifford.stmts.SingleQubitNonHermitianGate,
33+
):
34+
qubits = frame.get(stmt.qubits)
35+
cirq_op = getattr(cirq, stmt.name.upper())
36+
if stmt.adjoint:
37+
cirq_op = cirq_op ** (-1)
38+
39+
frame.circuit.append(cirq_op.on_each(qubits))
40+
return ()
41+
42+
@impl(clifford.stmts.SqrtX)
43+
@impl(clifford.stmts.SqrtY)
44+
def sqrt(
45+
self,
46+
emit: EmitCirq,
47+
frame: EmitCirqFrame,
48+
stmt: clifford.stmts.SqrtX | clifford.stmts.SqrtY,
49+
):
50+
qubits = frame.get(stmt.qubits)
51+
52+
exponent = 0.5
53+
if stmt.adjoint:
54+
exponent *= -1
55+
56+
if isinstance(stmt, clifford.stmts.SqrtX):
57+
cirq_op = cirq.XPowGate(exponent=exponent)
58+
else:
59+
cirq_op = cirq.YPowGate(exponent=exponent)
60+
61+
frame.circuit.append(cirq_op.on_each(qubits))
62+
return ()
63+
64+
@impl(clifford.stmts.CX)
65+
@impl(clifford.stmts.CY)
66+
@impl(clifford.stmts.CZ)
67+
def control(
68+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.ControlledGate
69+
):
70+
controls = frame.get(stmt.controls)
71+
targets = frame.get(stmt.targets)
72+
cirq_op = getattr(cirq, stmt.name.upper())
73+
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
74+
frame.circuit.append(cirq_op.on_each(cirq_qubits))
75+
return ()
76+
77+
@impl(clifford.stmts.Rx)
78+
@impl(clifford.stmts.Ry)
79+
@impl(clifford.stmts.Rz)
80+
def rot(
81+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: clifford.stmts.RotationGate
82+
):
83+
qubits = frame.get(stmt.qubits)
84+
85+
turns = frame.get(stmt.angle)
86+
angle = turns * 2 * math.pi
87+
cirq_op = getattr(cirq, stmt.name.title())(rads=angle)
88+
89+
frame.circuit.append(cirq_op.on_each(qubits))
90+
return ()
Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,53 @@
11
import cirq
2-
from kirin.emit import EmitError
32
from kirin.interp import MethodTable, impl
43

54
from bloqade.squin import noise
65

76
from .base import EmitCirq, EmitCirqFrame
8-
from .runtime import (
9-
KronRuntime,
10-
BasicOpRuntime,
11-
OperatorRuntimeABC,
12-
PauliStringRuntime,
13-
)
147

158

169
@noise.dialect.register(key="emit.cirq")
17-
class EmitCirqNoiseMethods(MethodTable):
10+
class __EmitCirqNoiseMethods(MethodTable):
11+
12+
two_qubit_paulis = (
13+
"IX",
14+
"IY",
15+
"IZ",
16+
"XI",
17+
"XX",
18+
"XY",
19+
"XZ",
20+
"YI",
21+
"YX",
22+
"YY",
23+
"YZ",
24+
"ZI",
25+
"ZX",
26+
"ZY",
27+
"ZZ",
28+
)
1829

1930
@impl(noise.stmts.Depolarize)
2031
def depolarize(
2132
self, interp: EmitCirq, frame: EmitCirqFrame, stmt: noise.stmts.Depolarize
2233
):
2334
p = frame.get(stmt.p)
24-
gate = cirq.depolarize(p, n_qubits=1)
25-
return (BasicOpRuntime(gate=gate),)
35+
qubits = frame.get(stmt.qubits)
36+
cirfq_op = cirq.depolarize(p, n_qubits=1).on_each(qubits)
37+
frame.circuit.append(cirfq_op)
38+
return ()
2639

2740
@impl(noise.stmts.Depolarize2)
2841
def depolarize2(
2942
self, interp: EmitCirq, frame: EmitCirqFrame, stmt: noise.stmts.Depolarize2
3043
):
3144
p = frame.get(stmt.p)
32-
gate = cirq.depolarize(p, n_qubits=2)
33-
return (BasicOpRuntime(gate=gate),)
45+
controls = frame.get(stmt.controls)
46+
targets = frame.get(stmt.targets)
47+
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
48+
cirq_op = cirq.depolarize(p, n_qubits=2).on_each(cirq_qubits)
49+
frame.circuit.append(cirq_op)
50+
return ()
3451

3552
@impl(noise.stmts.SingleQubitPauliChannel)
3653
def single_qubit_pauli_channel(
@@ -39,9 +56,15 @@ def single_qubit_pauli_channel(
3956
frame: EmitCirqFrame,
4057
stmt: noise.stmts.SingleQubitPauliChannel,
4158
):
42-
ps = frame.get(stmt.params)
43-
gate = cirq.asymmetric_depolarize(*ps)
44-
return (BasicOpRuntime(gate=gate),)
59+
px = frame.get(stmt.px)
60+
py = frame.get(stmt.py)
61+
pz = frame.get(stmt.pz)
62+
qubits = frame.get(stmt.qubits)
63+
64+
cirq_op = cirq.asymmetric_depolarize(px, py, pz).on_each(qubits)
65+
frame.circuit.append(cirq_op)
66+
67+
return ()
4568

4669
@impl(noise.stmts.TwoQubitPauliChannel)
4770
def two_qubit_pauli_channel(
@@ -50,33 +73,18 @@ def two_qubit_pauli_channel(
5073
frame: EmitCirqFrame,
5174
stmt: noise.stmts.TwoQubitPauliChannel,
5275
):
53-
ps = frame.get(stmt.params)
54-
paulis = ("I", "X", "Y", "Z")
55-
pauli_combinations = [
56-
pauli1 + pauli2
57-
for pauli1 in paulis
58-
for pauli2 in paulis
59-
if not (pauli1 == pauli2 == "I")
60-
]
61-
error_probabilities = {key: p for (key, p) in zip(pauli_combinations, ps)}
62-
gate = cirq.asymmetric_depolarize(error_probabilities=error_probabilities)
63-
return (BasicOpRuntime(gate),)
64-
65-
@staticmethod
66-
def _op_to_key(operator: OperatorRuntimeABC) -> str:
67-
match operator:
68-
case KronRuntime():
69-
key_lhs = EmitCirqNoiseMethods._op_to_key(operator.lhs)
70-
key_rhs = EmitCirqNoiseMethods._op_to_key(operator.rhs)
71-
return key_lhs + key_rhs
76+
ps = frame.get(stmt.probabilities)
77+
error_probabilities = {
78+
key: p for (key, p) in zip(self.two_qubit_paulis, ps) if p != 0
79+
}
7280

73-
case BasicOpRuntime():
74-
return str(operator.gate)
81+
controls = frame.get(stmt.controls)
82+
targets = frame.get(stmt.targets)
83+
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
7584

76-
case PauliStringRuntime():
77-
return operator.string
85+
cirq_op = cirq.asymmetric_depolarize(
86+
error_probabilities=error_probabilities
87+
).on_each(cirq_qubits)
88+
frame.circuit.append(cirq_op)
7889

79-
case _:
80-
raise EmitError(
81-
f"Unexpected operator runtime in StochasticUnitaryChannel of type {type(operator).__name__} encountered!"
82-
)
90+
return ()

src/bloqade/cirq_utils/emit/qubit.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
1414
n_qubits = frame.get(stmt.n_qubits)
1515

1616
if frame.qubits is not None:
17-
cirq_qubits = [frame.qubits[i + frame.qubit_index] for i in range(n_qubits)]
17+
cirq_qubits = tuple(
18+
frame.qubits[i + frame.qubit_index] for i in range(n_qubits)
19+
)
1820
else:
19-
cirq_qubits = [
21+
cirq_qubits = tuple(
2022
cirq.LineQubit(i + frame.qubit_index) for i in range(n_qubits)
21-
]
23+
)
2224

2325
frame.qubit_index += n_qubits
2426
return (cirq_qubits,)

0 commit comments

Comments
 (0)