Skip to content

Commit 8b7c6cd

Browse files
committed
Allow passing in and returning the quantum register when loading cirq.Circuits into squin (#313)
While looking into this I realized that the only thing that makes sense (IMO) to pass into or return from a circuit is the list of qubits. I added two simple keyword arguments to control the behavior. The result is that you can compose kernels that you can easily compose kernels that you generate from circuits. For example: ```python q = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q)) get_entangled_qubits = squin.cirq.load_circuit( circuit, return_register=True, kernel_name="get_entangled_qubits" ) get_entangled_qubits.print() entangle_qubits = squin.cirq.load_circuit( circuit, register_as_argument=True, kernel_name="entangle_qubits" ) @squin.kernel def main(): qreg = get_entangled_qubits() qreg2 = squin.qubit.new(1) entangle_qubits([qreg[1], qreg2[0]]) return squin.qubit.measure(qreg2) ``` Here, `get_entangled_qubits` allocates a new register, entangles it and returns the result, whereas `entangle_qubits` accepts a register of two qubits to entangle. Of course, you could also pass in and return from the same kernel by setting both `return_register=True` and `register_as_argument=True`. FYI, @jon-wurtz , let me know if this covers the use case you had in mind. Closes #302
1 parent ca45199 commit 8b7c6cd

File tree

3 files changed

+168
-27
lines changed

3 files changed

+168
-27
lines changed

src/bloqade/squin/cirq/__init__.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ def load_circuit(
1818
circuit: cirq.Circuit,
1919
kernel_name: str = "main",
2020
dialects: ir.DialectGroup = kernel,
21+
register_as_argument: bool = False,
22+
return_register: bool = False,
23+
register_argument_name: str = "q",
2124
globals: dict[str, Any] | None = None,
2225
file: str | None = None,
2326
lineno_offset: int = 0,
@@ -32,13 +35,23 @@ def load_circuit(
3235
Keyword Args:
3336
kernel_name (str): The name of the kernel to load. Defaults to "main".
3437
dialects (ir.DialectGroup | None): The dialects to use. Defaults to `squin.kernel`.
38+
register_as_argument (bool): Determine whether the resulting kernel function should accept
39+
a single `ilist.IList[Qubit, Any]` argument that is a list of qubits used within the
40+
function. This allows you to compose kernel functions generated from circuits.
41+
Defaults to `False`.
42+
return_register (bool): Determine whether the resulting kernel functionr returns a
43+
single value of type `ilist.IList[Qubit, Any]` that is the list of qubits used
44+
in the kernel function. Useful when you want to compose multiple kernel functions
45+
generated from circuits. Defaults to `False`.
46+
register_argument_name (str): The name of the argument that represents the qubit register.
47+
Only used when `register_as_argument=True`. Defaults to "q".
3548
globals (dict[str, Any] | None): The global variables to use. Defaults to None.
3649
file (str | None): The file name for error reporting. Defaults to None.
3750
lineno_offset (int): The line number offset for error reporting. Defaults to 0.
3851
col_offset (int): The column number offset for error reporting. Defaults to 0.
3952
compactify (bool): Whether to compactify the output. Defaults to True.
4053
41-
Example:
54+
## Usage Examples:
4255
4356
```python
4457
# from cirq's "hello qubit" example
@@ -60,6 +73,30 @@ def load_circuit(
6073
# print the resulting IR
6174
main.print()
6275
```
76+
77+
You can also compose kernel functions generated from circuits by passing in
78+
and / or returning the respective quantum registers:
79+
80+
```python
81+
q = cirq.LineQubit.range(2)
82+
circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q))
83+
84+
get_entangled_qubits = squin.cirq.load_circuit(
85+
circuit, return_register=True, kernel_name="get_entangled_qubits"
86+
)
87+
get_entangled_qubits.print()
88+
89+
entangle_qubits = squin.cirq.load_circuit(
90+
circuit, register_as_argument=True, kernel_name="entangle_qubits"
91+
)
92+
93+
@squin.kernel
94+
def main():
95+
qreg = get_entangled_qubits()
96+
qreg2 = squin.qubit.new(1)
97+
entangle_qubits([qreg[1], qreg2[0]])
98+
return squin.qubit.measure(qreg2)
99+
```
63100
"""
64101

65102
target = Squin(dialects=dialects, circuit=circuit)
@@ -71,24 +108,46 @@ def load_circuit(
71108
lineno_offset=lineno_offset,
72109
col_offset=col_offset,
73110
compactify=compactify,
111+
register_as_argument=register_as_argument,
112+
register_argument_name=register_argument_name,
74113
)
75114

76-
# NOTE: no return value
77-
return_value = func.ConstantNone()
78-
body.blocks[0].stmts.append(return_value)
79-
body.blocks[0].stmts.append(func.Return(value_or_stmt=return_value))
115+
if return_register:
116+
return_value = target.qreg
117+
else:
118+
return_value = func.ConstantNone()
119+
body.blocks[0].stmts.append(return_value)
120+
121+
return_node = func.Return(value_or_stmt=return_value)
122+
body.blocks[0].stmts.append(return_node)
123+
124+
self_arg_name = kernel_name + "_self"
125+
arg_names = [self_arg_name]
126+
if register_as_argument:
127+
args = (target.qreg.type,)
128+
arg_names.append(register_argument_name)
129+
else:
130+
args = ()
131+
132+
# NOTE: add _self as argument; need to know signature before so do it after lowering
133+
signature = func.Signature(args, return_node.value.type)
134+
body.blocks[0].args.insert_from(
135+
0,
136+
types.Generic(ir.Method, types.Tuple.where(signature.inputs), signature.output),
137+
self_arg_name,
138+
)
80139

81140
code = func.Function(
82141
sym_name=kernel_name,
83-
signature=func.Signature((), types.NoneType),
142+
signature=signature,
84143
body=body,
85144
)
86145

87146
return ir.Method(
88147
mod=None,
89148
py_func=None,
90149
sym_name=kernel_name,
91-
arg_names=[],
150+
arg_names=arg_names,
92151
dialects=dialects,
93152
code=code,
94153
)

src/bloqade/squin/cirq/lowering.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import field, dataclass
44

55
import cirq
6-
from kirin import ir, lowering
6+
from kirin import ir, types, lowering
77
from kirin.rewrite import Walk, CFGCompactify
88
from kirin.dialects import py, scf, ilist
99

@@ -25,20 +25,19 @@ class Squin(lowering.LoweringABC[CirqNode]):
2525
"""Lower a cirq.Circuit object to a squin kernel"""
2626

2727
circuit: cirq.Circuit
28-
qreg: qubit.New = field(init=False)
28+
qreg: ir.SSAValue = field(init=False)
2929
qreg_index: dict[cirq.Qid, int] = field(init=False, default_factory=dict)
3030
next_qreg_index: int = field(init=False, default=0)
3131

32-
def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid):
33-
index = self.qreg_index.get(qid)
34-
35-
if index is None:
36-
index = self.next_qreg_index
37-
self.qreg_index[qid] = index
38-
self.next_qreg_index += 1
32+
def __post_init__(self):
33+
# TODO: sort by cirq ordering
34+
qbits = sorted(self.circuit.all_qubits())
35+
self.qreg_index = {qid: idx for (idx, qid) in enumerate(qbits)}
3936

37+
def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid):
38+
index = self.qreg_index[qid]
4039
index_ssa = state.current_frame.push(py.Constant(index)).result
41-
qbit_getitem = state.current_frame.push(py.GetItem(self.qreg.result, index_ssa))
40+
qbit_getitem = state.current_frame.push(py.GetItem(self.qreg, index_ssa))
4241
return qbit_getitem.result
4342

4443
def lower_qubit_getindices(
@@ -64,6 +63,8 @@ def run(
6463
lineno_offset: int = 0,
6564
col_offset: int = 0,
6665
compactify: bool = True,
66+
register_as_argument: bool = False,
67+
register_argument_name: str = "q",
6768
) -> ir.Region:
6869

6970
state = lowering.State(
@@ -73,16 +74,21 @@ def run(
7374
col_offset=col_offset,
7475
)
7576

76-
with state.frame(
77-
[stmt],
78-
globals=globals,
79-
finalize_next=False,
80-
) as frame:
81-
# NOTE: create a global register of qubits first
82-
# TODO: can there be a circuit without qubits?
83-
n_qubits = cirq.num_qubits(self.circuit)
84-
n = frame.push(py.Constant(n_qubits))
85-
self.qreg = frame.push(qubit.New(n_qubits=n.result))
77+
with state.frame([stmt], globals=globals, finalize_next=False) as frame:
78+
79+
# NOTE: need a register of qubits before lowering statements
80+
if register_as_argument:
81+
# NOTE: register as argument to the kernel; we have freedom of choice for the name here
82+
frame.curr_block.args.append_from(
83+
ilist.IListType[qubit.QubitType, types.Any],
84+
name=register_argument_name,
85+
)
86+
self.qreg = frame.curr_block.args[0]
87+
else:
88+
# NOTE: create a new register of appropriate size
89+
n_qubits = len(self.qreg_index)
90+
n = frame.push(py.Constant(n_qubits))
91+
self.qreg = frame.push(qubit.New(n_qubits=n.result)).result
8692

8793
self.visit(state, stmt)
8894

test/squin/cirq/test_cirq_to_squin.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
import cirq
44
import pytest
5+
from kirin import types
6+
from kirin.passes import inline
7+
from kirin.dialects import ilist
58

69
from bloqade import squin
710
from bloqade.pyqrack import DynamicMemorySimulator
@@ -124,6 +127,19 @@ def noise_channels():
124127
)
125128

126129

130+
def nested_circuit():
131+
q = cirq.LineQubit.range(3)
132+
133+
return cirq.Circuit(
134+
cirq.H(q[0]),
135+
cirq.CircuitOperation(
136+
cirq.Circuit(cirq.H(q[1]), cirq.CX(q[1], q[2])).freeze(),
137+
use_repetition_ids=False,
138+
).controlled_by(q[0]),
139+
cirq.measure(*q),
140+
)
141+
142+
127143
@pytest.mark.parametrize(
128144
"circuit_f",
129145
[
@@ -161,6 +177,66 @@ def test_circuit(circuit_f, run_sim: bool = False):
161177
print(ket)
162178

163179

180+
def test_return_register():
181+
circuit = basic_circuit()
182+
kernel = squin.load_circuit(circuit, return_register=True)
183+
kernel.print()
184+
185+
assert isinstance(kernel.return_type, types.Generic)
186+
assert kernel.return_type.body.is_subseteq(ilist.IListType)
187+
188+
189+
@pytest.mark.xfail
190+
def test_nested_circuit():
191+
# TODO: lowering for CircuitOperation
192+
test_circuit(nested_circuit)
193+
194+
195+
def test_passing_in_register():
196+
circuit = pow_gate_circuit()
197+
print(circuit)
198+
kernel = squin.cirq.load_circuit(circuit, register_as_argument=True)
199+
kernel.print()
200+
201+
202+
def test_passing_and_returning_register():
203+
circuit = pow_gate_circuit()
204+
print(circuit)
205+
kernel = squin.cirq.load_circuit(
206+
circuit, register_as_argument=True, return_register=True
207+
)
208+
kernel.print()
209+
210+
211+
def test_nesting_lowered_circuit():
212+
q = cirq.LineQubit.range(2)
213+
circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q))
214+
215+
get_entangled_qubits = squin.cirq.load_circuit(
216+
circuit, return_register=True, kernel_name="get_entangled_qubits"
217+
)
218+
get_entangled_qubits.print()
219+
220+
entangle_qubits = squin.cirq.load_circuit(
221+
circuit, register_as_argument=True, kernel_name="entangle_qubits"
222+
)
223+
224+
@squin.kernel
225+
def main():
226+
qreg = get_entangled_qubits()
227+
qreg2 = squin.qubit.new(1)
228+
entangle_qubits([qreg[1], qreg2[0]])
229+
return squin.qubit.measure(qreg2)
230+
231+
# if you get up to here, the validation works
232+
main.print()
233+
234+
# inline to see if the IR is correct
235+
inline.InlinePass(main.dialects)(main)
236+
237+
main.print()
238+
239+
164240
def test_classical_control(run_sim: bool = False):
165241
q = cirq.LineQubit.range(2)
166242
circuit = cirq.Circuit(

0 commit comments

Comments
 (0)