From 52e88288e4cdff12d9b640833180a9caebc54d69 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 29 Apr 2025 09:08:58 -0400 Subject: [PATCH 1/4] introduce class hiearchy, add broadcast to wire successfully --- src/bloqade/squin/qubit.py | 25 ++++++++++++++----------- src/bloqade/squin/wire.py | 20 ++++++++++++++++---- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index 91646a6d..f13ff22d 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -20,25 +20,28 @@ dialect = ir.Dialect("squin.qubit") -@statement(dialect=dialect) -class New(ir.Statement): +@statement +class MultiQubitStatement(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) - n_qubits: ir.SSAValue = info.argument(types.Int) - result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any]) + operator: ir.SSAValue = info.argument(OpType) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) @statement(dialect=dialect) -class Apply(ir.Statement): - traits = frozenset({lowering.FromPythonCall()}) - operator: ir.SSAValue = info.argument(OpType) - qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) +class Apply(MultiQubitStatement): + pass @statement(dialect=dialect) -class Broadcast(ir.Statement): +class Broadcast(MultiQubitStatement): + pass + + +@statement(dialect=dialect) +class New(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) - operator: ir.SSAValue = info.argument(OpType) - qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) + n_qubits: ir.SSAValue = info.argument(types.Int) + result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any]) @statement(dialect=dialect) diff --git a/src/bloqade/squin/wire.py b/src/bloqade/squin/wire.py index ec46c957..0fb8be3d 100644 --- a/src/bloqade/squin/wire.py +++ b/src/bloqade/squin/wire.py @@ -48,10 +48,8 @@ class Unwrap(ir.Statement): result: ir.ResultValue = info.result(WireType) -# In Quake, you put a wire in and get a wire out when you "apply" an operator -# In this case though we just need to indicate that an operator is applied to list[wires] -@statement(dialect=dialect) -class Apply(ir.Statement): # apply(op, w1, w2, ...) +@statement +class MultiWireStatement(ir.Statement): traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) operator: ir.SSAValue = info.argument(OpType) inputs: tuple[ir.SSAValue, ...] = info.argument(WireType) @@ -68,6 +66,20 @@ def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): ) # custom lowering required for wrapper to work here +# In Quake, you put a wire in and get a wire out when you "apply" an operator +# In this case though we just need to indicate that an operator is applied to list[wires] +@statement(dialect=dialect) +class Apply(MultiWireStatement): # apply(op, w1, w2, ...) + def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): + super().__init__(operator, *args) + + +@statement(dialect=dialect) +class Broadcast(MultiWireStatement): + def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): + super().__init__(operator, *args) + + # NOTE: measurement cannot be pure because they will collapse the state # of the qubit. The state is a hidden state that is not visible to # the user in the wire dialect. From b2f1238d57260bfa04b0ebbdac1106bd4c325d2d Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 29 Apr 2025 13:20:48 -0400 Subject: [PATCH 2/4] remove subclasses --- src/bloqade/squin/qubit.py | 15 ++++++--------- src/bloqade/squin/wire.py | 28 +++++++++++++++++----------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index f13ff22d..22856ee6 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -20,21 +20,18 @@ dialect = ir.Dialect("squin.qubit") -@statement -class MultiQubitStatement(ir.Statement): +@statement(dialect=dialect) +class Apply(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) operator: ir.SSAValue = info.argument(OpType) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) @statement(dialect=dialect) -class Apply(MultiQubitStatement): - pass - - -@statement(dialect=dialect) -class Broadcast(MultiQubitStatement): - pass +class Broadcast(ir.Statement): + traits = frozenset({lowering.FromPythonCall()}) + operator: ir.SSAValue = info.argument(OpType) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) @statement(dialect=dialect) diff --git a/src/bloqade/squin/wire.py b/src/bloqade/squin/wire.py index 0fb8be3d..8ca0aae1 100644 --- a/src/bloqade/squin/wire.py +++ b/src/bloqade/squin/wire.py @@ -48,8 +48,10 @@ class Unwrap(ir.Statement): result: ir.ResultValue = info.result(WireType) -@statement -class MultiWireStatement(ir.Statement): +# In Quake, you put a wire in and get a wire out when you "apply" an operator +# In this case though we just need to indicate that an operator is applied to list[wires] +@statement(dialect=dialect) +class Apply(ir.Statement): traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) operator: ir.SSAValue = info.argument(OpType) inputs: tuple[ir.SSAValue, ...] = info.argument(WireType) @@ -66,18 +68,22 @@ def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): ) # custom lowering required for wrapper to work here -# In Quake, you put a wire in and get a wire out when you "apply" an operator -# In this case though we just need to indicate that an operator is applied to list[wires] @statement(dialect=dialect) -class Apply(MultiWireStatement): # apply(op, w1, w2, ...) - def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): - super().__init__(operator, *args) - +class Broadcast(ir.Statement): + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) + operator: ir.SSAValue = info.argument(OpType) + inputs: tuple[ir.SSAValue, ...] = info.argument(WireType) -@statement(dialect=dialect) -class Broadcast(MultiWireStatement): def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): - super().__init__(operator, *args) + result_types = tuple(WireType for _ in args) + super().__init__( + args=(operator,) + args, + result_types=result_types, # result types of the Apply statement, should all be WireTypes + args_slice={ + "operator": 0, + "inputs": slice(1, None), + }, # pretty printing + syntax sugar + ) # custom lowering required for wrapper to work here # NOTE: measurement cannot be pure because they will collapse the state From 31f2852599ea853c38b3210e125c8220aa53df81 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 30 Apr 2025 10:52:30 -0400 Subject: [PATCH 3/4] add unit tests, add impl per Phillip's request --- src/bloqade/squin/wire.py | 3 + test/squin/analysis/test_constprop_wire.py | 110 +++++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 test/squin/analysis/test_constprop_wire.py diff --git a/src/bloqade/squin/wire.py b/src/bloqade/squin/wire.py index 8ca0aae1..71cca7df 100644 --- a/src/bloqade/squin/wire.py +++ b/src/bloqade/squin/wire.py @@ -112,10 +112,13 @@ class Reset(ir.Statement): # Issue where constant propagation can't handle # multiple return values from Apply properly + + @dialect.register(key="constprop") class ConstPropWire(interp.MethodTable): @interp.impl(Apply) + @interp.impl(Broadcast) def apply(self, interp, frame, stmt: Apply): return frame.get_values(stmt.inputs) diff --git a/test/squin/analysis/test_constprop_wire.py b/test/squin/analysis/test_constprop_wire.py new file mode 100644 index 00000000..7355363b --- /dev/null +++ b/test/squin/analysis/test_constprop_wire.py @@ -0,0 +1,110 @@ +from kirin import ir, types +from kirin.passes import Fold +from kirin.dialects import py, func +from kirin.analysis.const import Propagate + +from bloqade import squin + + +def as_int(value: int): + return py.constant.Constant(value=value) + + +def as_float(value: float): + return py.constant.Constant(value=value) + + +def gen_func_from_stmts(stmts): + + squin_with_py = squin.groups.wired.add(py) + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=squin.op.types.OpType), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=squin_with_py, + code=func_wrapper, + arg_names=[], + ) + + fold_pass = Fold(squin_with_py) + fold_pass(constructed_method) + + return constructed_method + + +def test_constprop_wire_apply(): + + # generate method + + stmts = [ + (n_qubits := as_int(2)), + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), + (idx0 := as_int(0)), + (q0 := py.GetItem(qreg.result, idx0.result)), + (idx1 := as_int(1)), + (q1 := py.GetItem(qreg.result, idx1.result)), + # get wire + (w0 := squin.wire.Unwrap(q0.result)), + (w1 := squin.wire.Unwrap(q1.result)), + # put wire through gates + (x := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op=x.result, n_controls=1)), + # This triggers missing in prop analysis! + (vs := squin.wire.Apply(cx.result, w0.result, w1.result)), + (func.Return(vs.results[0])), + ] + + constructed_method = gen_func_from_stmts(stmts) + + prop_analysis = Propagate(constructed_method.dialects) + frame, _ = prop_analysis.run_analysis(constructed_method) + assert len(frame.entries.values()) == 13 + + +def test_constprop_wire_broadcast(): + + # generate method + + stmts = [ + (n_qubits := as_int(4)), + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), + (idx0 := as_int(0)), + (q0 := py.GetItem(qreg.result, idx0.result)), + (idx1 := as_int(1)), + (q1 := py.GetItem(qreg.result, idx1.result)), + (idx2 := as_int(2)), + (q2 := py.GetItem(qreg.result, idx2.result)), + (idx3 := as_int(3)), + (q3 := py.GetItem(qreg.result, idx3.result)), + # get wires + (w0 := squin.wire.Unwrap(q0.result)), + (w1 := squin.wire.Unwrap(q1.result)), + (w2 := squin.wire.Unwrap(q2.result)), + (w3 := squin.wire.Unwrap(q3.result)), + # put wire through gates + (x := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op=x.result, n_controls=1)), + # This triggers missing in prop analysis! + ( + vs := squin.wire.Broadcast( + cx.result, w0.result, w1.result, w2.result, w3.result + ) + ), + (func.Return(vs.results[0])), + ] + + constructed_method = gen_func_from_stmts(stmts) + + prop_analysis = Propagate(constructed_method.dialects) + frame, _ = prop_analysis.run_analysis(constructed_method) + + assert len(frame.entries.values()) == 21 From 668a98c4e04076d040334afd53951b56ea4aa3dd Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 30 Apr 2025 14:50:26 -0400 Subject: [PATCH 4/4] fix failing measure sugar tests --- src/bloqade/squin/qubit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index 4817ea0c..7a63559a 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -41,6 +41,7 @@ class New(ir.Statement): result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any]) +@statement(dialect=dialect) class MeasureAny(ir.Statement): name = "measure"