Skip to content

Commit e76b3f3

Browse files
committed
support for control gates confirmed
1 parent e55a8fb commit e76b3f3

File tree

3 files changed

+163
-46
lines changed

3 files changed

+163
-46
lines changed

src/bloqade/squin/analysis/nsites/impls.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from kirin import ir, interp
44

5-
from bloqade.squin import op
5+
from bloqade.squin import op, wire
66

77
from .lattice import (
88
NoSites,
@@ -11,6 +11,15 @@
1111
from .analysis import NSitesAnalysis
1212

1313

14+
@wire.dialect.register(key="op.nsites")
15+
class SquinWire(interp.MethodTable):
16+
17+
@interp.impl(wire.Apply)
18+
def apply(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.Apply):
19+
20+
return tuple([frame.get(input) for input in stmt.inputs])
21+
22+
1423
@op.dialect.register(key="op.nsites")
1524
class SquinOp(interp.MethodTable):
1625

src/bloqade/squin/rewrite/stim.py

Lines changed: 88 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,46 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator):
102102
return stim.gate.H
103103
case op.stmts.S():
104104
return stim.gate.S
105+
case op.stmts.Identity(): # enforce sites defined = num wires in
106+
return stim.gate.Identity
105107
case _:
106108
return None
107109

110+
# get the qubit indices from the Apply statement argument
111+
# wires/qubits
112+
def insert_qubit_idx_ssa(
113+
self, apply_stmt: wire.Apply | qubit.Apply
114+
) -> tuple[ir.SSAValue, ...]:
115+
116+
if isinstance(apply_stmt, qubit.Apply):
117+
qubits = apply_stmt.qubits
118+
address_attribute: AddressAttribute = self.get_address(qubits)
119+
# Should get an AddressTuple out of the address stored in attribute
120+
address_tuple = address_attribute.address
121+
qubit_idx_ssas: list[ir.SSAValue] = []
122+
for address_qubit in address_tuple.data:
123+
qubit_idx = address_qubit.data
124+
qubit_idx_stmt = py.Constant(qubit_idx)
125+
qubit_idx_stmt.insert_before(apply_stmt)
126+
qubit_idx_ssas.append(qubit_idx_stmt.result)
127+
128+
return tuple(qubit_idx_ssas)
129+
130+
elif isinstance(apply_stmt, wire.Apply):
131+
wire_ssas = apply_stmt.inputs
132+
qubit_idx_ssas: list[ir.SSAValue] = []
133+
for wire_ssa in wire_ssas:
134+
address_attribute = self.get_address(wire_ssa)
135+
# get parent qubit idx
136+
wire_address = address_attribute.address
137+
qubit_idx = wire_address.origin_qubit.data
138+
qubit_idx_stmt = py.Constant(qubit_idx)
139+
# accumulate all qubit idx SSA to instantiate stim gate stmt
140+
qubit_idx_ssas.append(qubit_idx_stmt.result)
141+
qubit_idx_stmt.insert_before(apply_stmt)
142+
143+
return tuple(qubit_idx_ssas)
144+
108145
# might be worth attempting multiple dispatch like qasm2 rewrites
109146
# for Glob and Parallel to UOp
110147
# The problem is I'd have to introduce names for all the statements
@@ -142,51 +179,60 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
142179
# this is an SSAValue, need it to be the actual operator
143180
applied_op = apply_stmt.operator.owner
144181

145-
# need to handle Identity and Control through separate means
146-
# but we can handle X, Y, Z, and H here just fine
147-
stim_1q_op = self.get_stim_1q_gate(applied_op)
148-
149-
if isinstance(apply_stmt, qubit.Apply):
150-
qubits = apply_stmt.qubits
151-
address_attribute: AddressAttribute = self.get_address(qubits)
152-
# Should get an AddressTuple out of the address stored in attribute
153-
address_tuple = address_attribute.address
154-
qubit_idx_ssas: list[ir.SSAValue] = []
155-
for address_qubit in address_tuple.data:
156-
qubit_idx = address_qubit.data
157-
qubit_idx_stmt = py.Constant(qubit_idx)
158-
qubit_idx_ssas.append(qubit_idx_stmt.result)
159-
qubit_idx_stmt.insert_before(apply_stmt)
160-
161-
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
162-
163-
# can't do any of this because of dependencies downstream
164-
# apply_stmt.replace_by(stim_1q_stmt)
165-
166-
return RewriteResult(has_done_something=True)
182+
if isinstance(applied_op, op.stmts.Control):
183+
return self.rewrite_Control(apply_stmt)
167184

168-
elif isinstance(apply_stmt, wire.Apply):
169-
wires_ssa = apply_stmt.inputs
170-
qubit_idx_ssas: list[ir.SSAValue] = []
171-
for wire_ssa in wires_ssa:
172-
address_attribute = self.get_address(wire_ssa)
173-
# get parent qubit idx
174-
wire_address = address_attribute.address
175-
qubit_idx = wire_address.origin_qubit.data
176-
qubit_idx_stmt = py.Constant(qubit_idx)
177-
# accumulate all qubit idx SSA to instantiate stim gate stmt
178-
qubit_idx_ssas.append(qubit_idx_stmt.result)
179-
qubit_idx_stmt.insert_before(apply_stmt)
185+
# need to handle Control through separate means
186+
# but we can handle X, Y, Z, H, and S here just fine
187+
stim_1q_op = self.get_stim_1q_gate(applied_op)
180188

181-
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
182-
stim_1q_stmt.insert_before(apply_stmt)
189+
qubit_idx_ssas = self.insert_qubit_idx_ssa(apply_stmt=apply_stmt)
190+
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
191+
stim_1q_stmt.insert_before(apply_stmt)
183192

184-
# There is something depending on the results of the statement,
185-
# need to handle that so replacement/deletion can occur without problems
193+
return RewriteResult(has_done_something=True)
186194

187-
# apply's results become wires that go to other apply's/wrap stmts
188-
# apply_stmt.replace_by(stim_1q_stmt)
195+
def rewrite_Control(
196+
self, apply_stmt_ctrl: qubit.Apply | wire.Apply
197+
) -> RewriteResult:
198+
# stim only supports CX, CY, CZ so we have to check the
199+
# operator of Apply is a Control gate, enforce it's only asking for 1 control qubit,
200+
# and that the target of the control is X, Y, Z in squin
201+
202+
ctrl_op: op.stmts.Control = apply_stmt_ctrl.operator.owner
203+
# enforce that n_controls is 1
204+
205+
ctrl_op_target_gate = ctrl_op.op.owner
206+
207+
# should enforce that this is some multiple of 2
208+
qubit_idx_ssas = self.insert_qubit_idx_ssa(apply_stmt=apply_stmt_ctrl)
209+
# according to stim, final result can be:
210+
# CX 1 2 3 4 -> CX(1, targ=2), CX(3, targ=4)
211+
target_qubits = []
212+
ctrl_qubits = []
213+
# definitely a better way to do this but
214+
# can't think of it right now
215+
for i in range(len(qubit_idx_ssas)):
216+
if (i % 2) == 0:
217+
ctrl_qubits.append(qubit_idx_ssas[i])
218+
else:
219+
target_qubits.append(qubit_idx_ssas[i])
220+
221+
target_qubits = tuple(target_qubits)
222+
ctrl_qubits = tuple(ctrl_qubits)
223+
224+
match ctrl_op_target_gate:
225+
case op.stmts.X():
226+
stim_stmt = stim.CX(controls=ctrl_qubits, targets=target_qubits)
227+
case op.stmts.Y():
228+
stim_stmt = stim.CY(controls=ctrl_qubits, targets=target_qubits)
229+
case op.stmts.Z():
230+
stim_stmt = stim.CZ(controls=ctrl_qubits, targets=target_qubits)
231+
case _:
232+
raise NotImplementedError(
233+
"Control gates beyond CX, CY, and CZ are not supported"
234+
)
189235

190-
return RewriteResult(has_done_something=True)
236+
stim_stmt.insert_before(apply_stmt_ctrl)
191237

192-
return RewriteResult()
238+
return RewriteResult(has_done_something=True)

test/squin/stim/stim.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,14 @@ def test_1q():
5858
# pass the wires through some 1 Qubit operators
5959
(op1 := squin.op.stmts.S()),
6060
(op2 := squin.op.stmts.H()),
61-
(op3 := squin.op.stmts.X()),
61+
(op3 := squin.op.stmts.Identity(sites=1)),
62+
(op4 := squin.op.stmts.Identity(sites=1)),
6263
(v0 := squin.wire.Apply(op1.result, w0.result)),
6364
(v1 := squin.wire.Apply(op2.result, v0.results[0])),
6465
(v2 := squin.wire.Apply(op3.result, v1.results[0])),
66+
(v3 := squin.wire.Apply(op4.result, v2.results[0])),
6567
(
66-
squin.wire.Wrap(v2.results[0], q0.result)
68+
squin.wire.Wrap(v3.results[0], q0.result)
6769
), # for wrap, just free a use for the result SSAval
6870
(ret_none := func.ConstantNone()),
6971
(func.Return(ret_none)),
@@ -107,4 +109,64 @@ def test_1q():
107109
constructed_method.print()
108110

109111

110-
test_1q()
112+
def test_control():
113+
114+
stmts: list[ir.Statement] = [
115+
# Create qubit register
116+
(n_qubits := as_int(2)),
117+
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
118+
# Get qubis out
119+
(idx0 := as_int(0)),
120+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
121+
(idx1 := as_int(1)),
122+
(q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)),
123+
# Unwrap to get wires
124+
(w0 := squin.wire.Unwrap(qubit=q0.result)),
125+
(w1 := squin.wire.Unwrap(qubit=q1.result)),
126+
# set up control gate
127+
(op1 := squin.op.stmts.X()),
128+
(cx := squin.op.stmts.Control(op1.result, n_controls=1)),
129+
(app := squin.wire.Apply(cx.result, w0.result, w1.result)),
130+
# wrap things back
131+
(squin.wire.Wrap(wire=app.results[0], qubit=q0.result)),
132+
(squin.wire.Wrap(wire=app.results[1], qubit=q1.result)),
133+
(ret_none := func.ConstantNone()),
134+
(func.Return(ret_none)),
135+
]
136+
137+
constructed_method = gen_func_from_stmts(stmts)
138+
constructed_method.print()
139+
140+
address_frame, _ = address.AddressAnalysis(
141+
constructed_method.dialects
142+
).run_analysis(constructed_method, no_raise=False)
143+
144+
nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
145+
constructed_method, no_raise=False
146+
)
147+
148+
constructed_method.print(analysis=address_frame.entries)
149+
constructed_method.print(analysis=nsites_frame.entries)
150+
151+
wrap_squin_analysis = WrapSquinAnalysis(
152+
address_analysis=address_frame.entries, op_site_analysis=nsites_frame.entries
153+
)
154+
fix_walk_squin_analysis = Fixpoint(Walk(wrap_squin_analysis))
155+
rewrite_res = fix_walk_squin_analysis.rewrite(constructed_method.code)
156+
157+
# attempt rewrite to Stim
158+
# Be careful with Fixpoint, can go to infinity until reaches defined threshold
159+
squin_to_stim = Walk(SquinToStim())
160+
rewrite_res = squin_to_stim.rewrite(constructed_method.code)
161+
162+
constructed_method.print()
163+
164+
# Get rid of the unused statements
165+
dce = Fixpoint(Walk(DeadCodeElimination()))
166+
rewrite_res = dce.rewrite(constructed_method.code)
167+
print(rewrite_res)
168+
169+
constructed_method.print()
170+
171+
172+
test_control()

0 commit comments

Comments
 (0)