diff --git a/src/bloqade/squin/wire.py b/src/bloqade/squin/wire.py index ec46c957..f2a70d0d 100644 --- a/src/bloqade/squin/wire.py +++ b/src/bloqade/squin/wire.py @@ -68,6 +68,25 @@ def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): ) # custom lowering required for wrapper to work here +# Carry over from Qubit dialect +@statement(dialect=dialect) +class Broadcast(ir.Statement): + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) + operator: ir.SSAValue = info.argument(OpType) + inputs: tuple[ir.SSAValue, ...] = info.argument(WireType) + + def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): + result_types = tuple(WireType for _ in args) + super().__init__( + args=(operator,) + args, + result_types=result_types, + args_slice={ + "operator": 0, + "inputs": slice(1, None), + }, # pretty printing + syntax sugar + ) # custom lowering required for wrapper to work here + + # NOTE: measurement cannot be pure because they will collapse the state # of the qubit. The state is a hidden state that is not visible to # the user in the wire dialect. @@ -98,6 +117,7 @@ class Reset(ir.Statement): class ConstPropWire(interp.MethodTable): @interp.impl(Apply) + @interp.impl(Broadcast) def apply(self, interp, frame, stmt: Apply): return frame.get_values(stmt.inputs) diff --git a/test/squin/test_constprop.py b/test/squin/test_constprop.py new file mode 100644 index 00000000..baf94e21 --- /dev/null +++ b/test/squin/test_constprop.py @@ -0,0 +1,110 @@ +# There's a method table in the wire dialect statements +# that handles the multiple return values from Apply and Broadcast +# that can cause problems for constant propoagation's default implementation. +# These tests just make sure there are corresponding lattice types per each +# SSA value (as opposed to a bunch of "missing" entries despite multiple +# return values from Broadcast and Apply) + +from kirin import ir, types +from kirin.passes import Fold +from kirin.dialects import py, func +from kirin.analysis.const import Propagate + +from bloqade import squin + + +def as_int(value: int): + return py.constant.Constant(value=value) + + +def as_float(value: float): + return py.constant.Constant(value=value) + + +def gen_func_from_stmts(stmts): + + squin_with_py = squin.groups.wired.add(py) + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=squin.op.types.OpType), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=squin_with_py, + code=func_wrapper, + arg_names=[], + ) + + fold_pass = Fold(squin_with_py) + fold_pass(constructed_method) + + return constructed_method + + +def test_wire_apply_constprop(): + + stmts = [ + (n_qubits := as_int(2)), + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), + (idx0 := as_int(0)), + (q0 := py.GetItem(qreg.result, idx0.result)), + (idx1 := as_int(1)), + (q1 := py.GetItem(qreg.result, idx1.result)), + # get wires + (w0 := squin.wire.Unwrap(q0.result)), + (w1 := squin.wire.Unwrap(q1.result)), + # put wire through gates + (x := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op=x.result, n_controls=1)), + (a := squin.wire.Apply(cx.result, w0.result, w1.result)), + (func.Return(a.results[0])), + ] + constructed_method = gen_func_from_stmts(stmts) + + prop_analysis = Propagate(constructed_method.dialects) + frame, _ = prop_analysis.run_analysis(constructed_method) + + assert len(frame.entries.values()) == 13 + + +def test_wire_broadcast_constprop(): + + stmts = [ + (n_qubits := as_int(4)), + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), + (idx0 := as_int(0)), + (q0 := py.GetItem(qreg.result, idx0.result)), + (idx1 := as_int(1)), + (q1 := py.GetItem(qreg.result, idx1.result)), + (idx2 := as_int(2)), + (q2 := py.GetItem(qreg.result, idx2.result)), + (idx3 := as_int(3)), + (q3 := py.GetItem(qreg.result, idx3.result)), + # get wires + (w0 := squin.wire.Unwrap(q0.result)), + (w1 := squin.wire.Unwrap(q1.result)), + (w2 := squin.wire.Unwrap(q2.result)), + (w3 := squin.wire.Unwrap(q3.result)), + # put wire through gates + (x := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op=x.result, n_controls=1)), + ( + a := squin.wire.Broadcast( + cx.result, w0.result, w1.result, w2.result, w3.result + ) + ), + (func.Return(a.results[0])), + ] + constructed_method = gen_func_from_stmts(stmts) + + prop_analysis = Propagate(constructed_method.dialects) + frame, _ = prop_analysis.run_analysis(constructed_method) + + assert len(frame.entries.values()) == 21