diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index 6d6c53c6..7a63559a 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -20,13 +20,6 @@ dialect = ir.Dialect("squin.qubit") -@statement(dialect=dialect) -class New(ir.Statement): - traits = frozenset({lowering.FromPythonCall()}) - n_qubits: ir.SSAValue = info.argument(types.Int) - result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any]) - - @statement(dialect=dialect) class Apply(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) @@ -41,6 +34,13 @@ class Broadcast(ir.Statement): qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) +@statement(dialect=dialect) +class New(ir.Statement): + traits = frozenset({lowering.FromPythonCall()}) + n_qubits: ir.SSAValue = info.argument(types.Int) + result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any]) + + @statement(dialect=dialect) class MeasureAny(ir.Statement): name = "measure" @@ -62,7 +62,6 @@ class MeasureQubit(ir.Statement): @statement(dialect=dialect) class MeasureQubitList(ir.Statement): name = "measure.qubit.list" - traits = frozenset({lowering.FromPythonCall()}) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) result: ir.ResultValue = info.result(ilist.IListType[types.Bool]) diff --git a/src/bloqade/squin/wire.py b/src/bloqade/squin/wire.py index ec46c957..71cca7df 100644 --- a/src/bloqade/squin/wire.py +++ b/src/bloqade/squin/wire.py @@ -51,7 +51,25 @@ class Unwrap(ir.Statement): # In Quake, you put a wire in and get a wire out when you "apply" an operator # In this case though we just need to indicate that an operator is applied to list[wires] @statement(dialect=dialect) -class Apply(ir.Statement): # apply(op, w1, w2, ...) +class Apply(ir.Statement): + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) + operator: ir.SSAValue = info.argument(OpType) + inputs: tuple[ir.SSAValue, ...] = info.argument(WireType) + + def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): + result_types = tuple(WireType for _ in args) + super().__init__( + args=(operator,) + args, + result_types=result_types, # result types of the Apply statement, should all be WireTypes + args_slice={ + "operator": 0, + "inputs": slice(1, None), + }, # pretty printing + syntax sugar + ) # custom lowering required for wrapper to work here + + +@statement(dialect=dialect) +class Broadcast(ir.Statement): traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) operator: ir.SSAValue = info.argument(OpType) inputs: tuple[ir.SSAValue, ...] = info.argument(WireType) @@ -94,10 +112,13 @@ class Reset(ir.Statement): # Issue where constant propagation can't handle # multiple return values from Apply properly + + @dialect.register(key="constprop") class ConstPropWire(interp.MethodTable): @interp.impl(Apply) + @interp.impl(Broadcast) def apply(self, interp, frame, stmt: Apply): return frame.get_values(stmt.inputs) diff --git a/test/squin/analysis/test_constprop_wire.py b/test/squin/analysis/test_constprop_wire.py new file mode 100644 index 00000000..7355363b --- /dev/null +++ b/test/squin/analysis/test_constprop_wire.py @@ -0,0 +1,110 @@ +from kirin import ir, types +from kirin.passes import Fold +from kirin.dialects import py, func +from kirin.analysis.const import Propagate + +from bloqade import squin + + +def as_int(value: int): + return py.constant.Constant(value=value) + + +def as_float(value: float): + return py.constant.Constant(value=value) + + +def gen_func_from_stmts(stmts): + + squin_with_py = squin.groups.wired.add(py) + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=squin.op.types.OpType), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=squin_with_py, + code=func_wrapper, + arg_names=[], + ) + + fold_pass = Fold(squin_with_py) + fold_pass(constructed_method) + + return constructed_method + + +def test_constprop_wire_apply(): + + # generate method + + stmts = [ + (n_qubits := as_int(2)), + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), + (idx0 := as_int(0)), + (q0 := py.GetItem(qreg.result, idx0.result)), + (idx1 := as_int(1)), + (q1 := py.GetItem(qreg.result, idx1.result)), + # get wire + (w0 := squin.wire.Unwrap(q0.result)), + (w1 := squin.wire.Unwrap(q1.result)), + # put wire through gates + (x := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op=x.result, n_controls=1)), + # This triggers missing in prop analysis! + (vs := squin.wire.Apply(cx.result, w0.result, w1.result)), + (func.Return(vs.results[0])), + ] + + constructed_method = gen_func_from_stmts(stmts) + + prop_analysis = Propagate(constructed_method.dialects) + frame, _ = prop_analysis.run_analysis(constructed_method) + assert len(frame.entries.values()) == 13 + + +def test_constprop_wire_broadcast(): + + # generate method + + stmts = [ + (n_qubits := as_int(4)), + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), + (idx0 := as_int(0)), + (q0 := py.GetItem(qreg.result, idx0.result)), + (idx1 := as_int(1)), + (q1 := py.GetItem(qreg.result, idx1.result)), + (idx2 := as_int(2)), + (q2 := py.GetItem(qreg.result, idx2.result)), + (idx3 := as_int(3)), + (q3 := py.GetItem(qreg.result, idx3.result)), + # get wires + (w0 := squin.wire.Unwrap(q0.result)), + (w1 := squin.wire.Unwrap(q1.result)), + (w2 := squin.wire.Unwrap(q2.result)), + (w3 := squin.wire.Unwrap(q3.result)), + # put wire through gates + (x := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op=x.result, n_controls=1)), + # This triggers missing in prop analysis! + ( + vs := squin.wire.Broadcast( + cx.result, w0.result, w1.result, w2.result, w3.result + ) + ), + (func.Return(vs.results[0])), + ] + + constructed_method = gen_func_from_stmts(stmts) + + prop_analysis = Propagate(constructed_method.dialects) + frame, _ = prop_analysis.run_analysis(constructed_method) + + assert len(frame.entries.values()) == 21