diff --git a/src/bloqade/analysis/address/__init__.py b/src/bloqade/analysis/address/__init__.py index 00826475..c004b7cc 100644 --- a/src/bloqade/analysis/address/__init__.py +++ b/src/bloqade/analysis/address/__init__.py @@ -4,6 +4,7 @@ NotQubit as NotQubit, AddressReg as AddressReg, AnyAddress as AnyAddress, + AddressWire as AddressWire, AddressQubit as AddressQubit, AddressTuple as AddressTuple, ) diff --git a/src/bloqade/analysis/address/analysis.py b/src/bloqade/analysis/address/analysis.py index 9ce84f83..d1438b00 100644 --- a/src/bloqade/analysis/address/analysis.py +++ b/src/bloqade/analysis/address/analysis.py @@ -11,7 +11,9 @@ class AddressAnalysis(Forward[Address]): - """This analysis pass can be used to track the global addresses of qubits.""" + """ + This analysis pass can be used to track the global addresses of qubits and wires. + """ keys = ["qubit.address"] lattice = Address diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index 483034e1..a9ae40e8 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -6,7 +6,16 @@ from kirin.analysis import ForwardFrame, const from kirin.dialects import cf, py, scf, func, ilist -from .lattice import Address, NotQubit, AddressReg, AddressQubit, AddressTuple +from bloqade import squin + +from .lattice import ( + Address, + NotQubit, + AddressReg, + AddressWire, + AddressQubit, + AddressTuple, +) from .analysis import AddressAnalysis @@ -64,10 +73,16 @@ def new_ilist( class PyIndexing(interp.MethodTable): @interp.impl(py.GetItem) def getitem(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.GetItem): + # Integer index into the thing being indexed idx = interp.get_const_value(int, stmt.index) + # The object being indexed into obj = frame.get(stmt.obj) + # The `data` attributes holds onto other Address types + # so we just extract that here if isinstance(obj, AddressTuple): return (obj.data[idx],) + # an AddressReg is guaranteed to just have some sequence + # of integers which is directly pluggable to AddressQubit elif isinstance(obj, AddressReg): return (AddressQubit(obj.data[idx]),) else: @@ -147,3 +162,67 @@ def for_loop( return # if terminate is Return, there is no result return loop_vars + + +# Address lattice elements we can work with: +## NotQubit (bottom), AnyAddress (top) + +## AddressTuple -> data: tuple[Address, ...] +### Recursive type, could contain itself or other variants +### This pops up in cases where you can have an IList/Tuple +### That contains elements that could be other Address types + +## AddressReg -> data: Sequence[int] +### specific to creation of a register of qubits + +## AddressQubit -> data: int +### Base qubit address type + + +@squin.wire.dialect.register(key="qubit.address") +class SquinWireMethodTable(interp.MethodTable): + + @interp.impl(squin.wire.Unwrap) + def unwrap( + self, + interp_: AddressAnalysis, + frame: ForwardFrame[Address], + stmt: squin.wire.Unwrap, + ): + + origin_qubit = frame.get(stmt.qubit) + + return (AddressWire(origin_qubit=origin_qubit),) + + @interp.impl(squin.wire.Apply) + def apply( + self, + interp_: AddressAnalysis, + frame: ForwardFrame[Address], + stmt: squin.wire.Apply, + ): + + origin_qubits = tuple( + [frame.get(input_elem).origin_qubit for input_elem in stmt.inputs] + ) + new_address_wires = tuple( + [AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits] + ) + return new_address_wires + + +@squin.qubit.dialect.register(key="qubit.address") +class SquinQubitMethodTable(interp.MethodTable): + + # This can be treated like a QRegNew impl + @interp.impl(squin.qubit.New) + def new( + self, + interp_: AddressAnalysis, + frame: ForwardFrame[Address], + stmt: squin.qubit.New, + ): + n_qubits = interp_.get_const_value(int, stmt.n_qubits) + addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits)) + interp_.next_address += n_qubits + return (addr,) diff --git a/src/bloqade/analysis/address/lattice.py b/src/bloqade/analysis/address/lattice.py index 4c523ad4..57c772ba 100644 --- a/src/bloqade/analysis/address/lattice.py +++ b/src/bloqade/analysis/address/lattice.py @@ -72,3 +72,14 @@ def is_subseteq(self, other: Address) -> bool: if isinstance(other, AddressQubit): return self.data == other.data return False + + +@final +@dataclass +class AddressWire(Address): + origin_qubit: AddressQubit + + def is_subseteq(self, other: Address) -> bool: + if isinstance(other, AddressWire): + return self.origin_qubit == self.origin_qubit + return False diff --git a/src/bloqade/qasm2/dialects/glob.py b/src/bloqade/qasm2/dialects/glob.py index f7d6c758..eb9f308e 100644 --- a/src/bloqade/qasm2/dialects/glob.py +++ b/src/bloqade/qasm2/dialects/glob.py @@ -5,7 +5,7 @@ from bloqade.qasm2.parse import ast from bloqade.qasm2.types import QRegType from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame -from bloqade.analysis.schedule import DagScheduleAnalysis +from bloqade.squin.analysis.schedule import DagScheduleAnalysis dialect = ir.Dialect("qasm2.glob") diff --git a/src/bloqade/qasm2/dialects/parallel.py b/src/bloqade/qasm2/dialects/parallel.py index 4f64aa84..8313e572 100644 --- a/src/bloqade/qasm2/dialects/parallel.py +++ b/src/bloqade/qasm2/dialects/parallel.py @@ -8,7 +8,7 @@ from bloqade.qasm2.parse import ast from bloqade.qasm2.types import QubitType from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame -from bloqade.analysis.schedule import DagScheduleAnalysis +from bloqade.squin.analysis.schedule import DagScheduleAnalysis dialect = ir.Dialect("qasm2.parallel") diff --git a/src/bloqade/qasm2/dialects/uop/schedule.py b/src/bloqade/qasm2/dialects/uop/schedule.py index 15955ec4..67f6c282 100644 --- a/src/bloqade/qasm2/dialects/uop/schedule.py +++ b/src/bloqade/qasm2/dialects/uop/schedule.py @@ -1,7 +1,7 @@ from kirin import interp from kirin.analysis import ForwardFrame -from bloqade.analysis.schedule import DagScheduleAnalysis +from bloqade.squin.analysis.schedule import DagScheduleAnalysis from . import stmts from ._dialect import dialect diff --git a/src/bloqade/qasm2/passes/parallel.py b/src/bloqade/qasm2/passes/parallel.py index 0fcd415b..8dd285d8 100644 --- a/src/bloqade/qasm2/passes/parallel.py +++ b/src/bloqade/qasm2/passes/parallel.py @@ -20,7 +20,7 @@ ) from kirin.analysis import const -from bloqade.analysis import address, schedule +from bloqade.analysis import address from bloqade.qasm2.rewrite import ( MergePolicyABC, ParallelToUOpRule, @@ -28,6 +28,7 @@ UOpToParallelRule, SimpleOptimalMergePolicy, ) +from bloqade.squin.analysis import schedule @dataclass diff --git a/src/bloqade/qasm2/rewrite/uop_to_parallel.py b/src/bloqade/qasm2/rewrite/uop_to_parallel.py index e8a3be72..aa7d5bb7 100644 --- a/src/bloqade/qasm2/rewrite/uop_to_parallel.py +++ b/src/bloqade/qasm2/rewrite/uop_to_parallel.py @@ -9,7 +9,7 @@ from bloqade.analysis import address from bloqade.qasm2.dialects import uop, core, parallel -from bloqade.analysis.schedule import StmtDag +from bloqade.squin.analysis.schedule import StmtDag class MergePolicyABC(abc.ABC): diff --git a/src/bloqade/squin/analysis/__init__.py b/src/bloqade/squin/analysis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/analysis/schedule.py b/src/bloqade/squin/analysis/schedule.py similarity index 100% rename from src/bloqade/analysis/schedule.py rename to src/bloqade/squin/analysis/schedule.py diff --git a/src/bloqade/squin/op/complex.py b/src/bloqade/squin/op/complex.py new file mode 100644 index 00000000..10e0d630 --- /dev/null +++ b/src/bloqade/squin/op/complex.py @@ -0,0 +1,6 @@ +# Stopgap Measure, squin dialect needs Complex type but +# this is only available in Kirin 0.15.x + +from kirin.ir.attrs.types import PyClass + +Complex = PyClass(complex) diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index 3a5d2b44..09fb6052 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -3,6 +3,7 @@ from .types import OpType from .traits import Sized, HasSize, Unitary, MaybeUnitary +from .complex import Complex from ._dialect import dialect @@ -53,7 +54,7 @@ class Scale(CompositeOp): traits = frozenset({ir.Pure(), ir.FromPythonCall(), MaybeUnitary()}) is_unitary: bool = info.attribute(default=False) op: ir.SSAValue = info.argument(OpType) - factor: ir.SSAValue = info.argument(types.Complex) + factor: ir.SSAValue = info.argument(Complex) result: ir.ResultValue = info.result(OpType) diff --git a/src/bloqade/squin/wire.py b/src/bloqade/squin/wire.py index 9578181d..07e41728 100644 --- a/src/bloqade/squin/wire.py +++ b/src/bloqade/squin/wire.py @@ -6,13 +6,17 @@ dialect. """ -from kirin import ir, types +from kirin import ir, types, interp from kirin.decl import info, statement from bloqade.types import QubitType from .op.types import OpType +# from kirin.lowering import wraps + +# from .op.types import Op, OpType + dialect = ir.Dialect("squin.wire") @@ -35,6 +39,8 @@ class Wrap(ir.Statement): qubit: ir.SSAValue = info.argument(QubitType) +# "Unwrap the quantum references to expose wires" -> From Quake Dialect documentation +# Unwrap(Qubit) -> Wire @statement(dialect=dialect) class Unwrap(ir.Statement): traits = frozenset({ir.FromPythonCall(), ir.Pure()}) @@ -42,19 +48,24 @@ class Unwrap(ir.Statement): result: ir.ResultValue = info.result(WireType) +# In Quake, you put a wire in and get a wire out when you "apply" an operator +# In this case though we just need to indicate that an operator is applied to list[wires] @statement(dialect=dialect) -class Apply(ir.Statement): +class Apply(ir.Statement): # apply(op, w1, w2, ...) traits = frozenset({ir.FromPythonCall(), ir.Pure()}) operator: ir.SSAValue = info.argument(OpType) - inputs: tuple[ir.SSAValue] = info.argument(WireType) + inputs: tuple[ir.SSAValue, ...] = info.argument(WireType) def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): result_types = tuple(WireType for _ in args) super().__init__( args=(operator,) + args, - result_types=result_types, - args_slice={"operator": 0, "inputs": slice(1, None)}, - ) + result_types=result_types, # result types of the Apply statement, should all be WireTypes + args_slice={ + "operator": 0, + "inputs": slice(1, None), + }, # pretty printing + syntax sugar + ) # custom lowering required for wrapper to work here # NOTE: measurement cannot be pure because they will collapse the state @@ -79,3 +90,14 @@ class MeasureAndReset(ir.Statement): class Reset(ir.Statement): traits = frozenset({ir.FromPythonCall(), WireTerminator()}) wire: ir.SSAValue = info.argument(WireType) + + +# Issue where constant propagation can't handle +# multiple return values from Apply properly +@dialect.register(key="constprop") +class ConstPropWire(interp.MethodTable): + + @interp.impl(Apply) + def apply(self, interp, frame, stmt: Apply): + + return frame.get_values(stmt.inputs) diff --git a/test/analysis/address/test_analysis.py b/test/analysis/address/test_analysis.py new file mode 100644 index 00000000..0d205b8d --- /dev/null +++ b/test/analysis/address/test_analysis.py @@ -0,0 +1,202 @@ +from kirin import ir, types +from kirin.passes import Fold +from kirin.dialects import py, func, ilist + +from bloqade import qasm2, squin +from bloqade.analysis import address + + +def as_int(value: int): + return py.constant.Constant(value=value) + + +squin_with_qasm_core = squin.groups.wired.add(qasm2.core).add(ilist) + + +def test_unwrap(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(2)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + # Unwrap to get wires + (w1 := squin.wire.Unwrap(qubit=q1.result)), + (w2 := squin.wire.Unwrap(qubit=q2.result)), + # Put them in an ilist and return to prevent elimination + (wire_list := ilist.New([w1.result, w2.result])), + (func.Return(wire_list)), + ] + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=ilist.IListType), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=squin_with_qasm_core, + code=func_wrapper, + arg_names=[], + ) + + fold_pass = Fold(squin_with_qasm_core) + fold_pass(constructed_method) + + frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + address_wires = [] + address_types = frame.entries.values() # dict[SSAValue, Address] + for address_type in address_types: + if isinstance(address_type, address.AddressWire): + address_wires.append(address_type) + + # 2 AddressWires should be produced from the Analysis + assert len(address_wires) == 2 + # The AddressWires should have qubits 0 and 1 as their parents + for qubit_idx, address_wire in enumerate(address_wires): + assert qubit_idx == address_wire.origin_qubit.data + + +## test unwrap + pass through single statements +def test_multiple_unwrap(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(2)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + # pass the wires through some 1 Qubit operators + (op1 := squin.op.stmts.T()), + (op2 := squin.op.stmts.H()), + (op3 := squin.op.stmts.X()), + (v0 := squin.wire.Apply(op1.result, w0.result)), + (v1 := squin.wire.Apply(op2.result, v0.results[0])), + (v2 := squin.wire.Apply(op3.result, w1.result)), + (wire_list := ilist.New([v1.results[0], v2.results[0]])), + (func.Return(wire_list)), + ] + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=ilist.IListType), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=squin_with_qasm_core, + code=func_wrapper, + arg_names=[], + ) + + fold_pass = Fold(squin_with_qasm_core) + fold_pass(constructed_method) + + frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + address_wire_parent_qubit_0 = [] + address_wire_parent_qubit_1 = [] + address_types = frame.entries.values() # dict[SSAValue, Address] + for address_type in address_types: + if isinstance(address_type, address.AddressWire): + if address_type.origin_qubit.data == 0: + address_wire_parent_qubit_0.append(address_type) + elif address_type.origin_qubit.data == 1: + address_wire_parent_qubit_1.append(address_type) + + # there should be 3 AddressWire instances with parent qubit 0 + # and 2 AddressWire instances with parent qubit 1 + assert len(address_wire_parent_qubit_0) == 3 + assert len(address_wire_parent_qubit_1) == 2 + + +def test_multiple_wire_apply(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(2)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + # Put the wires through a 2Q operator + (op1 := squin.op.stmts.X()), + (op2 := squin.op.stmts.Control(op1.result, n_controls=1)), + (apply_stmt := squin.wire.Apply(op2.result, w0.result, w1.result)), + # Inside constant prop, in eval_statement in the forward data analysis, + # Apply is marked as pure so frame.get_values(SSAValues) -> ValueType (where) + (wire_list := ilist.New([apply_stmt.results[0], apply_stmt.results[1]])), + (func.Return(wire_list.result)), + ] + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=ilist.IListType), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=squin_with_qasm_core, + code=func_wrapper, + arg_names=[], + ) + + fold_pass = Fold(squin_with_qasm_core) + fold_pass(constructed_method) + + # const_prop = const.Propagate(squin_with_qasm_core) + # frame, _ = const_prop.run_analysis(method=constructed_method, no_raise=False) + + frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + address_wire_parent_qubit_0 = [] + address_wire_parent_qubit_1 = [] + address_types = frame.entries.values() # dict[SSAValue, Address] + for address_type in address_types: + if isinstance(address_type, address.AddressWire): + if address_type.origin_qubit.data == 0: + address_wire_parent_qubit_0.append(address_type) + elif address_type.origin_qubit.data == 1: + address_wire_parent_qubit_1.append(address_type) + + # Should be 2 AddressWire instances with origin qubit 0 + # and another 2 with origin qubit 1 + assert len(address_wire_parent_qubit_0) == 2 + assert len(address_wire_parent_qubit_1) == 2