Skip to content

Commit e55a8fb

Browse files
committed
preliminary handling of Apply
1 parent 7c70b5e commit e55a8fb

File tree

3 files changed

+142
-6
lines changed

3 files changed

+142
-6
lines changed

src/bloqade/squin/rewrite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .stim import (
2+
SquinToStim as SquinToStim,
23
SitesAttribute as SitesAttribute,
34
AddressAttribute as AddressAttribute,
45
WrapSquinAnalysis as WrapSquinAnalysis,

src/bloqade/squin/rewrite/stim.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,27 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator):
115115
# constants, seems to be more for lowering from Python AST
116116

117117
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
118-
pass
118+
119+
match node:
120+
case wire.Apply() | qubit.Apply():
121+
return self.rewrite_Apply(node)
122+
case wire.Wrap():
123+
return self.rewrite_Wrap(node)
124+
case _:
125+
return RewriteResult()
126+
127+
return RewriteResult()
128+
129+
def rewrite_Wrap(self, wrap_stmt: wire.Wrap) -> RewriteResult:
130+
131+
# get the wire going into the statement
132+
wire_ssa = wrap_stmt.wire
133+
# remove the wrap statement altogether, then the wire that went into it
134+
wrap_stmt.delete()
135+
wire_ssa.delete()
136+
137+
# do NOT want to delete the qubit SSA! Leave that alone!
138+
return RewriteResult(has_done_something=True)
119139

120140
def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
121141

@@ -140,8 +160,8 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
140160

141161
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
142162

143-
apply_stmt.replace_by(stim_1q_stmt)
144-
apply_stmt.delete()
163+
# can't do any of this because of dependencies downstream
164+
# apply_stmt.replace_by(stim_1q_stmt)
145165

146166
return RewriteResult(has_done_something=True)
147167

@@ -151,16 +171,21 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
151171
for wire_ssa in wires_ssa:
152172
address_attribute = self.get_address(wire_ssa)
153173
# get parent qubit idx
154-
wire_address = address_attribute.data
174+
wire_address = address_attribute.address
155175
qubit_idx = wire_address.origin_qubit.data
156176
qubit_idx_stmt = py.Constant(qubit_idx)
177+
# accumulate all qubit idx SSA to instantiate stim gate stmt
157178
qubit_idx_ssas.append(qubit_idx_stmt.result)
158179
qubit_idx_stmt.insert_before(apply_stmt)
159180

160181
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
182+
stim_1q_stmt.insert_before(apply_stmt)
183+
184+
# There is something depending on the results of the statement,
185+
# need to handle that so replacement/deletion can occur without problems
161186

162-
apply_stmt.replace_by(stim_1q_stmt)
163-
apply_stmt.delete()
187+
# apply's results become wires that go to other apply's/wrap stmts
188+
# apply_stmt.replace_by(stim_1q_stmt)
164189

165190
return RewriteResult(has_done_something=True)
166191

test/squin/stim/stim.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from kirin import ir, types
2+
from kirin.passes import Fold
3+
from kirin.rewrite import Walk, Fixpoint, DeadCodeElimination
4+
from kirin.dialects import py, func, ilist
5+
6+
from bloqade import qasm2, squin
7+
from bloqade.analysis import address
8+
from bloqade.squin.rewrite import SquinToStim, WrapSquinAnalysis
9+
from bloqade.squin.analysis import nsites
10+
11+
12+
def as_int(value: int):
13+
return py.constant.Constant(value=value)
14+
15+
16+
def as_float(value: float):
17+
return py.constant.Constant(value=value)
18+
19+
20+
def gen_func_from_stmts(stmts):
21+
22+
extended_dialect = squin.groups.wired.add(qasm2.core).add(ilist)
23+
24+
block = ir.Block(stmts)
25+
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
26+
func_wrapper = func.Function(
27+
sym_name="main",
28+
signature=func.Signature(inputs=(), output=types.NoneType),
29+
body=ir.Region(blocks=block),
30+
)
31+
32+
constructed_method = ir.Method(
33+
mod=None,
34+
py_func=None,
35+
sym_name="main",
36+
dialects=extended_dialect,
37+
code=func_wrapper,
38+
arg_names=[],
39+
)
40+
41+
fold_pass = Fold(extended_dialect)
42+
fold_pass(constructed_method)
43+
44+
return constructed_method
45+
46+
47+
def test_1q():
48+
49+
stmts: list[ir.Statement] = [
50+
# Create qubit register
51+
(n_qubits := as_int(1)),
52+
(qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)),
53+
# Get qubit out
54+
(idx0 := as_int(0)),
55+
(q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)),
56+
# Unwrap to get wires
57+
(w0 := squin.wire.Unwrap(qubit=q0.result)),
58+
# pass the wires through some 1 Qubit operators
59+
(op1 := squin.op.stmts.S()),
60+
(op2 := squin.op.stmts.H()),
61+
(op3 := squin.op.stmts.X()),
62+
(v0 := squin.wire.Apply(op1.result, w0.result)),
63+
(v1 := squin.wire.Apply(op2.result, v0.results[0])),
64+
(v2 := squin.wire.Apply(op3.result, v1.results[0])),
65+
(
66+
squin.wire.Wrap(v2.results[0], q0.result)
67+
), # for wrap, just free a use for the result SSAval
68+
(ret_none := func.ConstantNone()),
69+
(func.Return(ret_none)),
70+
# the fact I return a wire here means DCE will NOT go ahead and
71+
# eliminate all the other wire.Apply stmts
72+
]
73+
74+
constructed_method = gen_func_from_stmts(stmts)
75+
76+
constructed_method.print()
77+
78+
address_frame, _ = address.AddressAnalysis(
79+
constructed_method.dialects
80+
).run_analysis(constructed_method, no_raise=False)
81+
82+
nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
83+
constructed_method, no_raise=False
84+
)
85+
86+
constructed_method.print(analysis=address_frame.entries)
87+
constructed_method.print(analysis=nsites_frame.entries)
88+
89+
# attempt to wrap analysis results
90+
91+
wrap_squin_analysis = WrapSquinAnalysis(
92+
address_analysis=address_frame.entries, op_site_analysis=nsites_frame.entries
93+
)
94+
fix_walk_squin_analysis = Fixpoint(Walk(wrap_squin_analysis))
95+
rewrite_res = fix_walk_squin_analysis.rewrite(constructed_method.code)
96+
97+
# attempt rewrite to Stim
98+
# Be careful with Fixpoint, can go to infinity until reaches defined threshold
99+
squin_to_stim = Walk(SquinToStim())
100+
rewrite_res = squin_to_stim.rewrite(constructed_method.code)
101+
102+
# Get rid of the unused statements
103+
dce = Fixpoint(Walk(DeadCodeElimination()))
104+
rewrite_res = dce.rewrite(constructed_method.code)
105+
print(rewrite_res)
106+
107+
constructed_method.print()
108+
109+
110+
test_1q()

0 commit comments

Comments
 (0)