Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/bloqade/squin/wire.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,25 @@ def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue):
) # custom lowering required for wrapper to work here


# Carry over from Qubit dialect
@statement(dialect=dialect)
class Broadcast(ir.Statement):
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
operator: ir.SSAValue = info.argument(OpType)
inputs: tuple[ir.SSAValue, ...] = info.argument(WireType)

def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue):
result_types = tuple(WireType for _ in args)
super().__init__(
args=(operator,) + args,
result_types=result_types,
args_slice={
"operator": 0,
"inputs": slice(1, None),
}, # pretty printing + syntax sugar
) # custom lowering required for wrapper to work here


# NOTE: measurement cannot be pure because they will collapse the state
# of the qubit. The state is a hidden state that is not visible to
# the user in the wire dialect.
Expand Down Expand Up @@ -98,6 +117,7 @@ class Reset(ir.Statement):
class ConstPropWire(interp.MethodTable):

@interp.impl(Apply)
@interp.impl(Broadcast)
def apply(self, interp, frame, stmt: Apply):

return frame.get_values(stmt.inputs)
110 changes: 110 additions & 0 deletions test/squin/test_constprop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# There's a method table in the wire dialect statements
# that handles the multiple return values from Apply and Broadcast
# that can cause problems for constant propoagation's default implementation.
# These tests just make sure there are corresponding lattice types per each
# SSA value (as opposed to a bunch of "missing" entries despite multiple
# return values from Broadcast and Apply)

from kirin import ir, types
from kirin.passes import Fold
from kirin.dialects import py, func
from kirin.analysis.const import Propagate

from bloqade import squin


def as_int(value: int):
return py.constant.Constant(value=value)


def as_float(value: float):
return py.constant.Constant(value=value)


def gen_func_from_stmts(stmts):

squin_with_py = squin.groups.wired.add(py)

block = ir.Block(stmts)
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
func_wrapper = func.Function(
sym_name="main",
signature=func.Signature(inputs=(), output=squin.op.types.OpType),
body=ir.Region(blocks=block),
)

constructed_method = ir.Method(
mod=None,
py_func=None,
sym_name="main",
dialects=squin_with_py,
code=func_wrapper,
arg_names=[],
)

fold_pass = Fold(squin_with_py)
fold_pass(constructed_method)

return constructed_method


def test_wire_apply_constprop():

stmts = [
(n_qubits := as_int(2)),
(qreg := squin.qubit.New(n_qubits=n_qubits.result)),
(idx0 := as_int(0)),
(q0 := py.GetItem(qreg.result, idx0.result)),
(idx1 := as_int(1)),
(q1 := py.GetItem(qreg.result, idx1.result)),
# get wires
(w0 := squin.wire.Unwrap(q0.result)),
(w1 := squin.wire.Unwrap(q1.result)),
# put wire through gates
(x := squin.op.stmts.X()),
(cx := squin.op.stmts.Control(op=x.result, n_controls=1)),
(a := squin.wire.Apply(cx.result, w0.result, w1.result)),
(func.Return(a.results[0])),
]
constructed_method = gen_func_from_stmts(stmts)

prop_analysis = Propagate(constructed_method.dialects)
frame, _ = prop_analysis.run_analysis(constructed_method)

assert len(frame.entries.values()) == 13


def test_wire_broadcast_constprop():

stmts = [
(n_qubits := as_int(4)),
(qreg := squin.qubit.New(n_qubits=n_qubits.result)),
(idx0 := as_int(0)),
(q0 := py.GetItem(qreg.result, idx0.result)),
(idx1 := as_int(1)),
(q1 := py.GetItem(qreg.result, idx1.result)),
(idx2 := as_int(2)),
(q2 := py.GetItem(qreg.result, idx2.result)),
(idx3 := as_int(3)),
(q3 := py.GetItem(qreg.result, idx3.result)),
# get wires
(w0 := squin.wire.Unwrap(q0.result)),
(w1 := squin.wire.Unwrap(q1.result)),
(w2 := squin.wire.Unwrap(q2.result)),
(w3 := squin.wire.Unwrap(q3.result)),
# put wire through gates
(x := squin.op.stmts.X()),
(cx := squin.op.stmts.Control(op=x.result, n_controls=1)),
(
a := squin.wire.Broadcast(
cx.result, w0.result, w1.result, w2.result, w3.result
)
),
(func.Return(a.results[0])),
]
constructed_method = gen_func_from_stmts(stmts)

prop_analysis = Propagate(constructed_method.dialects)
frame, _ = prop_analysis.run_analysis(constructed_method)

assert len(frame.entries.values()) == 21