Skip to content

Commit badf341

Browse files
authored
Add broadcast to wire (#220)
This adds `broadcast` from the qubit dialect into the `wire` dialect as well for squin. I added unit tests specifically for constprop as well and took @weinbe58 's feedback to add things into the constprop impl. I'll be very honest, I just carried over my work from #217 because I think I did some improper merge and it broke something somewhere else 😅 and the CI failing but my local run working was annoying me too much.
1 parent 0d37bf2 commit badf341

File tree

2 files changed

+130
-0
lines changed

2 files changed

+130
-0
lines changed

src/bloqade/squin/wire.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,25 @@ def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue):
6868
) # custom lowering required for wrapper to work here
6969

7070

71+
# Carry over from Qubit dialect
72+
@statement(dialect=dialect)
73+
class Broadcast(ir.Statement):
74+
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
75+
operator: ir.SSAValue = info.argument(OpType)
76+
inputs: tuple[ir.SSAValue, ...] = info.argument(WireType)
77+
78+
def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue):
79+
result_types = tuple(WireType for _ in args)
80+
super().__init__(
81+
args=(operator,) + args,
82+
result_types=result_types,
83+
args_slice={
84+
"operator": 0,
85+
"inputs": slice(1, None),
86+
}, # pretty printing + syntax sugar
87+
) # custom lowering required for wrapper to work here
88+
89+
7190
# NOTE: measurement cannot be pure because they will collapse the state
7291
# of the qubit. The state is a hidden state that is not visible to
7392
# the user in the wire dialect.
@@ -98,6 +117,7 @@ class Reset(ir.Statement):
98117
class ConstPropWire(interp.MethodTable):
99118

100119
@interp.impl(Apply)
120+
@interp.impl(Broadcast)
101121
def apply(self, interp, frame, stmt: Apply):
102122

103123
return frame.get_values(stmt.inputs)

test/squin/test_constprop.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# There's a method table in the wire dialect statements
2+
# that handles the multiple return values from Apply and Broadcast
3+
# that can cause problems for constant propoagation's default implementation.
4+
# These tests just make sure there are corresponding lattice types per each
5+
# SSA value (as opposed to a bunch of "missing" entries despite multiple
6+
# return values from Broadcast and Apply)
7+
8+
from kirin import ir, types
9+
from kirin.passes import Fold
10+
from kirin.dialects import py, func
11+
from kirin.analysis.const import Propagate
12+
13+
from bloqade import squin
14+
15+
16+
def as_int(value: int):
17+
return py.constant.Constant(value=value)
18+
19+
20+
def as_float(value: float):
21+
return py.constant.Constant(value=value)
22+
23+
24+
def gen_func_from_stmts(stmts):
25+
26+
squin_with_py = squin.groups.wired.add(py)
27+
28+
block = ir.Block(stmts)
29+
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
30+
func_wrapper = func.Function(
31+
sym_name="main",
32+
signature=func.Signature(inputs=(), output=squin.op.types.OpType),
33+
body=ir.Region(blocks=block),
34+
)
35+
36+
constructed_method = ir.Method(
37+
mod=None,
38+
py_func=None,
39+
sym_name="main",
40+
dialects=squin_with_py,
41+
code=func_wrapper,
42+
arg_names=[],
43+
)
44+
45+
fold_pass = Fold(squin_with_py)
46+
fold_pass(constructed_method)
47+
48+
return constructed_method
49+
50+
51+
def test_wire_apply_constprop():
52+
53+
stmts = [
54+
(n_qubits := as_int(2)),
55+
(qreg := squin.qubit.New(n_qubits=n_qubits.result)),
56+
(idx0 := as_int(0)),
57+
(q0 := py.GetItem(qreg.result, idx0.result)),
58+
(idx1 := as_int(1)),
59+
(q1 := py.GetItem(qreg.result, idx1.result)),
60+
# get wires
61+
(w0 := squin.wire.Unwrap(q0.result)),
62+
(w1 := squin.wire.Unwrap(q1.result)),
63+
# put wire through gates
64+
(x := squin.op.stmts.X()),
65+
(cx := squin.op.stmts.Control(op=x.result, n_controls=1)),
66+
(a := squin.wire.Apply(cx.result, w0.result, w1.result)),
67+
(func.Return(a.results[0])),
68+
]
69+
constructed_method = gen_func_from_stmts(stmts)
70+
71+
prop_analysis = Propagate(constructed_method.dialects)
72+
frame, _ = prop_analysis.run_analysis(constructed_method)
73+
74+
assert len(frame.entries.values()) == 13
75+
76+
77+
def test_wire_broadcast_constprop():
78+
79+
stmts = [
80+
(n_qubits := as_int(4)),
81+
(qreg := squin.qubit.New(n_qubits=n_qubits.result)),
82+
(idx0 := as_int(0)),
83+
(q0 := py.GetItem(qreg.result, idx0.result)),
84+
(idx1 := as_int(1)),
85+
(q1 := py.GetItem(qreg.result, idx1.result)),
86+
(idx2 := as_int(2)),
87+
(q2 := py.GetItem(qreg.result, idx2.result)),
88+
(idx3 := as_int(3)),
89+
(q3 := py.GetItem(qreg.result, idx3.result)),
90+
# get wires
91+
(w0 := squin.wire.Unwrap(q0.result)),
92+
(w1 := squin.wire.Unwrap(q1.result)),
93+
(w2 := squin.wire.Unwrap(q2.result)),
94+
(w3 := squin.wire.Unwrap(q3.result)),
95+
# put wire through gates
96+
(x := squin.op.stmts.X()),
97+
(cx := squin.op.stmts.Control(op=x.result, n_controls=1)),
98+
(
99+
a := squin.wire.Broadcast(
100+
cx.result, w0.result, w1.result, w2.result, w3.result
101+
)
102+
),
103+
(func.Return(a.results[0])),
104+
]
105+
constructed_method = gen_func_from_stmts(stmts)
106+
107+
prop_analysis = Propagate(constructed_method.dialects)
108+
frame, _ = prop_analysis.run_analysis(constructed_method)
109+
110+
assert len(frame.entries.values()) == 21

0 commit comments

Comments
 (0)