From 51f27b50492419132e2d8b6413d299de1f85c361 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 28 Apr 2025 10:25:50 -0400 Subject: [PATCH 1/5] adding unit tests --- src/bloqade/squin/qubit.py | 36 +++++++++++++++++++++++++++----- test/squin/test_measure_sugar.py | 26 +++++++++++++++++++++++ 2 files changed, 57 insertions(+), 5 deletions(-) create mode 100644 test/squin/test_measure_sugar.py diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index b18fc710..5276ff52 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -7,7 +7,7 @@ - `kirin.dialects.ilist`: provides the `ilist.IListType` type for lists of qubits. """ -from typing import Any +from typing import Any, overload from kirin import ir, types, lowering from kirin.decl import info, statement @@ -42,7 +42,27 @@ class Broadcast(ir.Statement): @statement(dialect=dialect) -class Measure(ir.Statement): +class MeasureAny(ir.Statement): + name = "measure" + + traits = frozenset({lowering.FromPythonCall()}) + inputs: ir.SSAValue = info.argument(types.Any) + result: ir.ResultValue = info.result(types.Any) + + +@statement(dialect=dialect) +class MeasureQubit(ir.Statement): + name = "measure.qubit" + + traits = frozenset({lowering.FromPythonCall()}) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) + result: ir.ResultValue = info.result(ilist.IListType[types.Bool]) + + +@statement(dialect=dialect) +class MeasureReg(ir.Statement): + name = "measure.reg" + traits = frozenset({lowering.FromPythonCall()}) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) result: ir.ResultValue = info.result(ilist.IListType[types.Bool]) @@ -89,9 +109,15 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None: ... -@wraps(Measure) -def measure(qubits: ilist.IList[Qubit, Any]) -> int: - """Measure the qubits in the list." +@overload +def measure(qubit: Qubit) -> bool: ... +@overload +def measure(qubit: ilist.IList[Qubit, Any] | list[Qubit]) -> list[bool]: ... + + +@wraps(MeasureAny) +def measure(qubit: Any) -> Any: + """Measure a qubit or qubits in the list." Args: qubits: The list of qubits to measure. diff --git a/test/squin/test_measure_sugar.py b/test/squin/test_measure_sugar.py new file mode 100644 index 00000000..db270c78 --- /dev/null +++ b/test/squin/test_measure_sugar.py @@ -0,0 +1,26 @@ +from bloqade import squin + + +def test_measure_register(): + @squin.kernel + def test_measure_sugar(): + q = squin.qubit.new(2) + + return squin.qubit.measure(q) + + assert isinstance( + test_measure_sugar.callable_region.blocks[-1].last_stmt, squin.qubit.MeasureReg + ) + + +def test_measure_qubit(): + @squin.kernel + def test_measure_sugar(): + q = squin.qubit.new(2) + + return squin.qubit.measure(q[0]) + + assert isinstance( + test_measure_sugar.callable_region.blocks[-1].last_stmt, + squin.qubit.MeasureQubit, + ) From 39058e2064432df70ce6aad453eccaded4c46f60 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 28 Apr 2025 11:30:57 -0400 Subject: [PATCH 2/5] fixing test --- test/squin/test_measure_sugar.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/squin/test_measure_sugar.py b/test/squin/test_measure_sugar.py index db270c78..7c34b170 100644 --- a/test/squin/test_measure_sugar.py +++ b/test/squin/test_measure_sugar.py @@ -1,6 +1,16 @@ +from kirin import ir +from kirin.dialects import func + from bloqade import squin +def get_return_value_stmt(kernel: ir.Method): + assert isinstance( + last_stmt := kernel.callable_region.blocks[-1].last_stmt, func.Return + ) + return last_stmt.value.owner + + def test_measure_register(): @squin.kernel def test_measure_sugar(): @@ -8,9 +18,7 @@ def test_measure_sugar(): return squin.qubit.measure(q) - assert isinstance( - test_measure_sugar.callable_region.blocks[-1].last_stmt, squin.qubit.MeasureReg - ) + assert isinstance(get_return_value_stmt(test_measure_sugar), squin.qubit.MeasureReg) def test_measure_qubit(): @@ -21,6 +29,6 @@ def test_measure_sugar(): return squin.qubit.measure(q[0]) assert isinstance( - test_measure_sugar.callable_region.blocks[-1].last_stmt, + get_return_value_stmt(test_measure_sugar), squin.qubit.MeasureQubit, ) From d5bbc6ad0a0daf7d959d3619d0cbade67cc87b1a Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Tue, 29 Apr 2025 11:47:09 -0400 Subject: [PATCH 3/5] adding rewrite pass to desugar kernel --- src/bloqade/squin/groups.py | 4 +++ src/bloqade/squin/qubit.py | 16 +++++----- src/bloqade/squin/rewrite/__init__.py | 0 src/bloqade/squin/rewrite/measure_desugar.py | 31 ++++++++++++++++++++ 4 files changed, 44 insertions(+), 7 deletions(-) create mode 100644 src/bloqade/squin/rewrite/__init__.py create mode 100644 src/bloqade/squin/rewrite/measure_desugar.py diff --git a/src/bloqade/squin/groups.py b/src/bloqade/squin/groups.py index 2a06e302..dda7a5b1 100644 --- a/src/bloqade/squin/groups.py +++ b/src/bloqade/squin/groups.py @@ -1,8 +1,10 @@ from kirin import ir, passes from kirin.prelude import structural_no_opt from kirin.dialects import ilist +from kirin.rewrite.walk import Walk from . import op, wire, qubit +from .rewrite.measure_desugar import MeasureDesugarRule @ir.dialect_group(structural_no_opt.union([op, qubit])) @@ -10,6 +12,7 @@ def kernel(self): fold_pass = passes.Fold(self) typeinfer_pass = passes.TypeInfer(self) ilist_desugar_pass = ilist.IListDesugar(self) + measure_desugar_pass = Walk(MeasureDesugarRule()) def run_pass(method: ir.Method, *, fold=True, typeinfer=True): method.verify() @@ -18,6 +21,7 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True): if typeinfer: typeinfer_pass(method) + measure_desugar_pass.rewrite(method.code) ilist_desugar_pass(method) if typeinfer: typeinfer_pass(method) # fix types after desugaring diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index 941d7566..dbdb0399 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -46,7 +46,7 @@ class MeasureAny(ir.Statement): name = "measure" traits = frozenset({lowering.FromPythonCall()}) - inputs: ir.SSAValue = info.argument(types.Any) + input: ir.SSAValue = info.argument(types.Any) result: ir.ResultValue = info.result(types.Any) @@ -55,7 +55,7 @@ class MeasureQubit(ir.Statement): name = "measure.qubit" traits = frozenset({lowering.FromPythonCall()}) - qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) + qubit: ir.SSAValue = info.argument(ilist.IListType[QubitType]) result: ir.ResultValue = info.result(ilist.IListType[types.Bool]) @@ -110,20 +110,22 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None: @overload -def measure(qubit: Qubit) -> bool: ... +def measure(input: Qubit) -> bool: ... @overload -def measure(qubit: ilist.IList[Qubit, Any] | list[Qubit]) -> list[bool]: ... +def measure(input: ilist.IList[Qubit, Any] | list[Qubit]) -> list[bool]: ... @wraps(MeasureAny) -def measure(qubit: Any) -> Any: +def measure(input: Any) -> Any: """Measure a qubit or qubits in the list. Args: - qubits: The list of qubits to measure. + input: A qubit or a list of qubits to measure. Returns: - int: The result of the measurement. + bool | list[bool]: The result of the measurement. If a single qubit is measured, + a single boolean is returned. If a list of qubits is measured, a list of booleans + is returned. """ ... diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/squin/rewrite/measure_desugar.py b/src/bloqade/squin/rewrite/measure_desugar.py new file mode 100644 index 00000000..20ab6909 --- /dev/null +++ b/src/bloqade/squin/rewrite/measure_desugar.py @@ -0,0 +1,31 @@ +from kirin import ir, types +from kirin.dialects import ilist +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade.squin.qubit import QubitType, MeasureAny, MeasureReg, MeasureQubit + + +class MeasureDesugarRule(RewriteRule): + """ + Desugar measure operations in the circuit. + """ + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + if not isinstance(node, MeasureAny): + return RewriteResult() + + if node.input.type.is_subseteq(QubitType): + node.replace_by( + MeasureQubit( + qubit=node.input, + ) + ) + elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]): + node.replace_by( + MeasureReg( + qubits=node.input, + ) + ) + + return RewriteResult() From 8a5e0cacb243bad319c0088361b37010ddafab38 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Tue, 29 Apr 2025 11:50:41 -0400 Subject: [PATCH 4/5] adding appropriate return value --- src/bloqade/squin/rewrite/measure_desugar.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/bloqade/squin/rewrite/measure_desugar.py b/src/bloqade/squin/rewrite/measure_desugar.py index 20ab6909..8bdbc2ea 100644 --- a/src/bloqade/squin/rewrite/measure_desugar.py +++ b/src/bloqade/squin/rewrite/measure_desugar.py @@ -21,11 +21,13 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: qubit=node.input, ) ) + return RewriteResult(has_done_something=True) elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]): node.replace_by( MeasureReg( qubits=node.input, ) ) + return RewriteResult(has_done_something=True) return RewriteResult() From 871499887dbbe8368c158dfe09b1387f0431919e Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Tue, 29 Apr 2025 12:06:13 -0400 Subject: [PATCH 5/5] updating names --- src/bloqade/squin/qubit.py | 4 ++-- src/bloqade/squin/rewrite/measure_desugar.py | 4 ++-- test/squin/test_measure_sugar.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index dbdb0399..6d6c53c6 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -60,8 +60,8 @@ class MeasureQubit(ir.Statement): @statement(dialect=dialect) -class MeasureReg(ir.Statement): - name = "measure.reg" +class MeasureQubitList(ir.Statement): + name = "measure.qubit.list" traits = frozenset({lowering.FromPythonCall()}) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) diff --git a/src/bloqade/squin/rewrite/measure_desugar.py b/src/bloqade/squin/rewrite/measure_desugar.py index 8bdbc2ea..a6b0f095 100644 --- a/src/bloqade/squin/rewrite/measure_desugar.py +++ b/src/bloqade/squin/rewrite/measure_desugar.py @@ -2,7 +2,7 @@ from kirin.dialects import ilist from kirin.rewrite.abc import RewriteRule, RewriteResult -from bloqade.squin.qubit import QubitType, MeasureAny, MeasureReg, MeasureQubit +from bloqade.squin.qubit import QubitType, MeasureAny, MeasureQubit, MeasureQubitList class MeasureDesugarRule(RewriteRule): @@ -24,7 +24,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return RewriteResult(has_done_something=True) elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]): node.replace_by( - MeasureReg( + MeasureQubitList( qubits=node.input, ) ) diff --git a/test/squin/test_measure_sugar.py b/test/squin/test_measure_sugar.py index 7c34b170..36ec617f 100644 --- a/test/squin/test_measure_sugar.py +++ b/test/squin/test_measure_sugar.py @@ -18,7 +18,9 @@ def test_measure_sugar(): return squin.qubit.measure(q) - assert isinstance(get_return_value_stmt(test_measure_sugar), squin.qubit.MeasureReg) + assert isinstance( + get_return_value_stmt(test_measure_sugar), squin.qubit.MeasureQubitList + ) def test_measure_qubit():