From 6c9be647a9b824f9da07a212548d25bfa9f76a77 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 30 Apr 2025 15:46:40 -0400 Subject: [PATCH 1/2] broadcast in wire dialect along with constprop tests --- src/bloqade/squin/wire.py | 20 +++++++ test/squin/test_constprop.py | 113 +++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 test/squin/test_constprop.py diff --git a/src/bloqade/squin/wire.py b/src/bloqade/squin/wire.py index ec46c957..f2a70d0d 100644 --- a/src/bloqade/squin/wire.py +++ b/src/bloqade/squin/wire.py @@ -68,6 +68,25 @@ def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): ) # custom lowering required for wrapper to work here +# Carry over from Qubit dialect +@statement(dialect=dialect) +class Broadcast(ir.Statement): + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) + operator: ir.SSAValue = info.argument(OpType) + inputs: tuple[ir.SSAValue, ...] = info.argument(WireType) + + def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): + result_types = tuple(WireType for _ in args) + super().__init__( + args=(operator,) + args, + result_types=result_types, + args_slice={ + "operator": 0, + "inputs": slice(1, None), + }, # pretty printing + syntax sugar + ) # custom lowering required for wrapper to work here + + # NOTE: measurement cannot be pure because they will collapse the state # of the qubit. The state is a hidden state that is not visible to # the user in the wire dialect. @@ -98,6 +117,7 @@ class Reset(ir.Statement): class ConstPropWire(interp.MethodTable): @interp.impl(Apply) + @interp.impl(Broadcast) def apply(self, interp, frame, stmt: Apply): return frame.get_values(stmt.inputs) diff --git a/test/squin/test_constprop.py b/test/squin/test_constprop.py new file mode 100644 index 00000000..17e73753 --- /dev/null +++ b/test/squin/test_constprop.py @@ -0,0 +1,113 @@ +# There's a method table in the wire dialect statements +# that handles the multiple return values from Apply and Broadcast +# that can cause problems for constant propoagation's default implementation. +# These tests just make sure there are corresponding lattice types per each +# SSA value (as opposed to a bunch of "missing" entries despite multiple +# return values from Broadcast and Apply) + +from kirin import ir, types +from kirin.passes import Fold +from kirin.dialects import py, func +from kirin.analysis.const import Propagate + +from bloqade import squin + + +def as_int(value: int): + return py.constant.Constant(value=value) + + +def as_float(value: float): + return py.constant.Constant(value=value) + + +def gen_func_from_stmts(stmts): + + squin_with_py = squin.groups.wired.add(py) + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=squin.op.types.OpType), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=squin_with_py, + code=func_wrapper, + arg_names=[], + ) + + fold_pass = Fold(squin_with_py) + fold_pass(constructed_method) + + return constructed_method + + +def test_wire_apply_constprop(): + + stmts = [ + (n_qubits := as_int(2)), + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), + (idx0 := as_int(0)), + (q0 := py.GetItem(qreg.result, idx0.result)), + (idx1 := as_int(1)), + (q1 := py.GetItem(qreg.result, idx1.result)), + # get wires + (w0 := squin.wire.Unwrap(q0.result)), + (w1 := squin.wire.Unwrap(q1.result)), + # put wire through gates + (x := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op=x.result, n_controls=1)), + (a := squin.wire.Apply(cx.result, w0.result, w1.result)), + (func.Return(a.results[0])), + ] + constructed_method = gen_func_from_stmts(stmts) + + prop_analysis = Propagate(constructed_method.dialects) + frame, _ = prop_analysis.run_analysis(constructed_method) + + assert len(frame.entries.values()) == 13 + + +def test_wire_broadcast_constprop(): + + stmts = [ + (n_qubits := as_int(4)), + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), + (idx0 := as_int(0)), + (q0 := py.GetItem(qreg.result, idx0.result)), + (idx1 := as_int(1)), + (q1 := py.GetItem(qreg.result, idx1.result)), + (idx2 := as_int(2)), + (q2 := py.GetItem(qreg.result, idx2.result)), + (idx3 := as_int(3)), + (q3 := py.GetItem(qreg.result, idx3.result)), + # get wires + (w0 := squin.wire.Unwrap(q0.result)), + (w1 := squin.wire.Unwrap(q1.result)), + (w2 := squin.wire.Unwrap(q2.result)), + (w3 := squin.wire.Unwrap(q3.result)), + # put wire through gates + (x := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op=x.result, n_controls=1)), + ( + a := squin.wire.Broadcast( + cx.result, w0.result, w1.result, w2.result, w3.result + ) + ), + (func.Return(a.results[0])), + ] + constructed_method = gen_func_from_stmts(stmts) + + prop_analysis = Propagate(constructed_method.dialects) + frame, _ = prop_analysis.run_analysis(constructed_method) + + assert len(frame.entries.values()) == 21 + + +test_wire_broadcast_constprop() From 1a3589553f88697670b8e80e29d51eb00c721e42 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 30 Apr 2025 16:10:54 -0400 Subject: [PATCH 2/2] remove test invocation --- test/squin/test_constprop.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/squin/test_constprop.py b/test/squin/test_constprop.py index 17e73753..baf94e21 100644 --- a/test/squin/test_constprop.py +++ b/test/squin/test_constprop.py @@ -108,6 +108,3 @@ def test_wire_broadcast_constprop(): frame, _ = prop_analysis.run_analysis(constructed_method) assert len(frame.entries.values()) == 21 - - -test_wire_broadcast_constprop()