diff --git a/src/bloqade/squin/groups.py b/src/bloqade/squin/groups.py index 2a06e302..dda7a5b1 100644 --- a/src/bloqade/squin/groups.py +++ b/src/bloqade/squin/groups.py @@ -1,8 +1,10 @@ from kirin import ir, passes from kirin.prelude import structural_no_opt from kirin.dialects import ilist +from kirin.rewrite.walk import Walk from . import op, wire, qubit +from .rewrite.measure_desugar import MeasureDesugarRule @ir.dialect_group(structural_no_opt.union([op, qubit])) @@ -10,6 +12,7 @@ def kernel(self): fold_pass = passes.Fold(self) typeinfer_pass = passes.TypeInfer(self) ilist_desugar_pass = ilist.IListDesugar(self) + measure_desugar_pass = Walk(MeasureDesugarRule()) def run_pass(method: ir.Method, *, fold=True, typeinfer=True): method.verify() @@ -18,6 +21,7 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True): if typeinfer: typeinfer_pass(method) + measure_desugar_pass.rewrite(method.code) ilist_desugar_pass(method) if typeinfer: typeinfer_pass(method) # fix types after desugaring diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index 91646a6d..6d6c53c6 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -7,7 +7,7 @@ - `kirin.dialects.ilist`: provides the `ilist.IListType` type for lists of qubits. """ -from typing import Any +from typing import Any, overload from kirin import ir, types, lowering from kirin.decl import info, statement @@ -42,7 +42,27 @@ class Broadcast(ir.Statement): @statement(dialect=dialect) -class Measure(ir.Statement): +class MeasureAny(ir.Statement): + name = "measure" + + traits = frozenset({lowering.FromPythonCall()}) + input: ir.SSAValue = info.argument(types.Any) + result: ir.ResultValue = info.result(types.Any) + + +@statement(dialect=dialect) +class MeasureQubit(ir.Statement): + name = "measure.qubit" + + traits = frozenset({lowering.FromPythonCall()}) + qubit: ir.SSAValue = info.argument(ilist.IListType[QubitType]) + result: ir.ResultValue = info.result(ilist.IListType[types.Bool]) + + +@statement(dialect=dialect) +class MeasureQubitList(ir.Statement): + name = "measure.qubit.list" + traits = frozenset({lowering.FromPythonCall()}) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) result: ir.ResultValue = info.result(ilist.IListType[types.Bool]) @@ -89,31 +109,39 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None: ... -@wraps(Broadcast) -def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None: - """Broadcast and apply an operator to a list of qubits. For example, an operator - that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0. +@overload +def measure(input: Qubit) -> bool: ... +@overload +def measure(input: ilist.IList[Qubit, Any] | list[Qubit]) -> list[bool]: ... + + +@wraps(MeasureAny) +def measure(input: Any) -> Any: + """Measure a qubit or qubits in the list. Args: - operator: The operator to broadcast and apply. - qubits: The list of qubits to broadcast and apply the operator to. The size of the list - must be inferable and match the number of qubits expected by the operator. + input: A qubit or a list of qubits to measure. Returns: - None + bool | list[bool]: The result of the measurement. If a single qubit is measured, + a single boolean is returned. If a list of qubits is measured, a list of booleans + is returned. """ ... -@wraps(Measure) -def measure(qubits: ilist.IList[Qubit, Any]) -> int: - """Measure the qubits in the list." +@wraps(Broadcast) +def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None: + """Broadcast and apply an operator to a list of qubits. For example, an operator + that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0. Args: - qubits: The list of qubits to measure. + operator: The operator to broadcast and apply. + qubits: The list of qubits to broadcast and apply the operator to. The size of the list + must be inferable and match the number of qubits expected by the operator. Returns: - int: The result of the measurement. + None """ ... diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/squin/rewrite/measure_desugar.py b/src/bloqade/squin/rewrite/measure_desugar.py new file mode 100644 index 00000000..a6b0f095 --- /dev/null +++ b/src/bloqade/squin/rewrite/measure_desugar.py @@ -0,0 +1,33 @@ +from kirin import ir, types +from kirin.dialects import ilist +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade.squin.qubit import QubitType, MeasureAny, MeasureQubit, MeasureQubitList + + +class MeasureDesugarRule(RewriteRule): + """ + Desugar measure operations in the circuit. + """ + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if not isinstance(node, MeasureAny): + return RewriteResult() + + if node.input.type.is_subseteq(QubitType): + node.replace_by( + MeasureQubit( + qubit=node.input, + ) + ) + return RewriteResult(has_done_something=True) + elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]): + node.replace_by( + MeasureQubitList( + qubits=node.input, + ) + ) + return RewriteResult(has_done_something=True) + + return RewriteResult() diff --git a/test/squin/test_measure_sugar.py b/test/squin/test_measure_sugar.py new file mode 100644 index 00000000..36ec617f --- /dev/null +++ b/test/squin/test_measure_sugar.py @@ -0,0 +1,36 @@ +from kirin import ir +from kirin.dialects import func + +from bloqade import squin + + +def get_return_value_stmt(kernel: ir.Method): + assert isinstance( + last_stmt := kernel.callable_region.blocks[-1].last_stmt, func.Return + ) + return last_stmt.value.owner + + +def test_measure_register(): + @squin.kernel + def test_measure_sugar(): + q = squin.qubit.new(2) + + return squin.qubit.measure(q) + + assert isinstance( + get_return_value_stmt(test_measure_sugar), squin.qubit.MeasureQubitList + ) + + +def test_measure_qubit(): + @squin.kernel + def test_measure_sugar(): + q = squin.qubit.new(2) + + return squin.qubit.measure(q[0]) + + assert isinstance( + get_return_value_stmt(test_measure_sugar), + squin.qubit.MeasureQubit, + )