Skip to content

Commit 31f2852

Browse files
committed
add unit tests, add impl per Phillip's request
1 parent 17db0d4 commit 31f2852

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

src/bloqade/squin/wire.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,13 @@ class Reset(ir.Statement):
112112

113113
# Issue where constant propagation can't handle
114114
# multiple return values from Apply properly
115+
116+
115117
@dialect.register(key="constprop")
116118
class ConstPropWire(interp.MethodTable):
117119

118120
@interp.impl(Apply)
121+
@interp.impl(Broadcast)
119122
def apply(self, interp, frame, stmt: Apply):
120123

121124
return frame.get_values(stmt.inputs)
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from kirin import ir, types
2+
from kirin.passes import Fold
3+
from kirin.dialects import py, func
4+
from kirin.analysis.const import Propagate
5+
6+
from bloqade import squin
7+
8+
9+
def as_int(value: int):
10+
return py.constant.Constant(value=value)
11+
12+
13+
def as_float(value: float):
14+
return py.constant.Constant(value=value)
15+
16+
17+
def gen_func_from_stmts(stmts):
18+
19+
squin_with_py = squin.groups.wired.add(py)
20+
21+
block = ir.Block(stmts)
22+
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
23+
func_wrapper = func.Function(
24+
sym_name="main",
25+
signature=func.Signature(inputs=(), output=squin.op.types.OpType),
26+
body=ir.Region(blocks=block),
27+
)
28+
29+
constructed_method = ir.Method(
30+
mod=None,
31+
py_func=None,
32+
sym_name="main",
33+
dialects=squin_with_py,
34+
code=func_wrapper,
35+
arg_names=[],
36+
)
37+
38+
fold_pass = Fold(squin_with_py)
39+
fold_pass(constructed_method)
40+
41+
return constructed_method
42+
43+
44+
def test_constprop_wire_apply():
45+
46+
# generate method
47+
48+
stmts = [
49+
(n_qubits := as_int(2)),
50+
(qreg := squin.qubit.New(n_qubits=n_qubits.result)),
51+
(idx0 := as_int(0)),
52+
(q0 := py.GetItem(qreg.result, idx0.result)),
53+
(idx1 := as_int(1)),
54+
(q1 := py.GetItem(qreg.result, idx1.result)),
55+
# get wire
56+
(w0 := squin.wire.Unwrap(q0.result)),
57+
(w1 := squin.wire.Unwrap(q1.result)),
58+
# put wire through gates
59+
(x := squin.op.stmts.X()),
60+
(cx := squin.op.stmts.Control(op=x.result, n_controls=1)),
61+
# This triggers missing in prop analysis!
62+
(vs := squin.wire.Apply(cx.result, w0.result, w1.result)),
63+
(func.Return(vs.results[0])),
64+
]
65+
66+
constructed_method = gen_func_from_stmts(stmts)
67+
68+
prop_analysis = Propagate(constructed_method.dialects)
69+
frame, _ = prop_analysis.run_analysis(constructed_method)
70+
assert len(frame.entries.values()) == 13
71+
72+
73+
def test_constprop_wire_broadcast():
74+
75+
# generate method
76+
77+
stmts = [
78+
(n_qubits := as_int(4)),
79+
(qreg := squin.qubit.New(n_qubits=n_qubits.result)),
80+
(idx0 := as_int(0)),
81+
(q0 := py.GetItem(qreg.result, idx0.result)),
82+
(idx1 := as_int(1)),
83+
(q1 := py.GetItem(qreg.result, idx1.result)),
84+
(idx2 := as_int(2)),
85+
(q2 := py.GetItem(qreg.result, idx2.result)),
86+
(idx3 := as_int(3)),
87+
(q3 := py.GetItem(qreg.result, idx3.result)),
88+
# get wires
89+
(w0 := squin.wire.Unwrap(q0.result)),
90+
(w1 := squin.wire.Unwrap(q1.result)),
91+
(w2 := squin.wire.Unwrap(q2.result)),
92+
(w3 := squin.wire.Unwrap(q3.result)),
93+
# put wire through gates
94+
(x := squin.op.stmts.X()),
95+
(cx := squin.op.stmts.Control(op=x.result, n_controls=1)),
96+
# This triggers missing in prop analysis!
97+
(
98+
vs := squin.wire.Broadcast(
99+
cx.result, w0.result, w1.result, w2.result, w3.result
100+
)
101+
),
102+
(func.Return(vs.results[0])),
103+
]
104+
105+
constructed_method = gen_func_from_stmts(stmts)
106+
107+
prop_analysis = Propagate(constructed_method.dialects)
108+
frame, _ = prop_analysis.run_analysis(constructed_method)
109+
110+
assert len(frame.entries.values()) == 21

0 commit comments

Comments
 (0)