Skip to content

Commit f039789

Browse files
authored
Measure Refactor. (#213)
This PR addresses #158
1 parent 009b73d commit f039789

File tree

5 files changed

+116
-15
lines changed

5 files changed

+116
-15
lines changed

src/bloqade/squin/groups.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
from kirin import ir, passes
22
from kirin.prelude import structural_no_opt
33
from kirin.dialects import ilist
4+
from kirin.rewrite.walk import Walk
45

56
from . import op, wire, qubit
7+
from .rewrite.measure_desugar import MeasureDesugarRule
68

79

810
@ir.dialect_group(structural_no_opt.union([op, qubit]))
911
def kernel(self):
1012
fold_pass = passes.Fold(self)
1113
typeinfer_pass = passes.TypeInfer(self)
1214
ilist_desugar_pass = ilist.IListDesugar(self)
15+
measure_desugar_pass = Walk(MeasureDesugarRule())
1316

1417
def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
1518
method.verify()
@@ -18,6 +21,7 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
1821

1922
if typeinfer:
2023
typeinfer_pass(method)
24+
measure_desugar_pass.rewrite(method.code)
2125
ilist_desugar_pass(method)
2226
if typeinfer:
2327
typeinfer_pass(method) # fix types after desugaring

src/bloqade/squin/qubit.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
- `kirin.dialects.ilist`: provides the `ilist.IListType` type for lists of qubits.
88
"""
99

10-
from typing import Any
10+
from typing import Any, overload
1111

1212
from kirin import ir, types, lowering
1313
from kirin.decl import info, statement
@@ -42,7 +42,27 @@ class Broadcast(ir.Statement):
4242

4343

4444
@statement(dialect=dialect)
45-
class Measure(ir.Statement):
45+
class MeasureAny(ir.Statement):
46+
name = "measure"
47+
48+
traits = frozenset({lowering.FromPythonCall()})
49+
input: ir.SSAValue = info.argument(types.Any)
50+
result: ir.ResultValue = info.result(types.Any)
51+
52+
53+
@statement(dialect=dialect)
54+
class MeasureQubit(ir.Statement):
55+
name = "measure.qubit"
56+
57+
traits = frozenset({lowering.FromPythonCall()})
58+
qubit: ir.SSAValue = info.argument(ilist.IListType[QubitType])
59+
result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
60+
61+
62+
@statement(dialect=dialect)
63+
class MeasureQubitList(ir.Statement):
64+
name = "measure.qubit.list"
65+
4666
traits = frozenset({lowering.FromPythonCall()})
4767
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
4868
result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
@@ -89,31 +109,39 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
89109
...
90110

91111

92-
@wraps(Broadcast)
93-
def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
94-
"""Broadcast and apply an operator to a list of qubits. For example, an operator
95-
that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0.
112+
@overload
113+
def measure(input: Qubit) -> bool: ...
114+
@overload
115+
def measure(input: ilist.IList[Qubit, Any] | list[Qubit]) -> list[bool]: ...
116+
117+
118+
@wraps(MeasureAny)
119+
def measure(input: Any) -> Any:
120+
"""Measure a qubit or qubits in the list.
96121
97122
Args:
98-
operator: The operator to broadcast and apply.
99-
qubits: The list of qubits to broadcast and apply the operator to. The size of the list
100-
must be inferable and match the number of qubits expected by the operator.
123+
input: A qubit or a list of qubits to measure.
101124
102125
Returns:
103-
None
126+
bool | list[bool]: The result of the measurement. If a single qubit is measured,
127+
a single boolean is returned. If a list of qubits is measured, a list of booleans
128+
is returned.
104129
"""
105130
...
106131

107132

108-
@wraps(Measure)
109-
def measure(qubits: ilist.IList[Qubit, Any]) -> int:
110-
"""Measure the qubits in the list."
133+
@wraps(Broadcast)
134+
def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
135+
"""Broadcast and apply an operator to a list of qubits. For example, an operator
136+
that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0.
111137
112138
Args:
113-
qubits: The list of qubits to measure.
139+
operator: The operator to broadcast and apply.
140+
qubits: The list of qubits to broadcast and apply the operator to. The size of the list
141+
must be inferable and match the number of qubits expected by the operator.
114142
115143
Returns:
116-
int: The result of the measurement.
144+
None
117145
"""
118146
...
119147

src/bloqade/squin/rewrite/__init__.py

Whitespace-only changes.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from kirin import ir, types
2+
from kirin.dialects import ilist
3+
from kirin.rewrite.abc import RewriteRule, RewriteResult
4+
5+
from bloqade.squin.qubit import QubitType, MeasureAny, MeasureQubit, MeasureQubitList
6+
7+
8+
class MeasureDesugarRule(RewriteRule):
9+
"""
10+
Desugar measure operations in the circuit.
11+
"""
12+
13+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
14+
15+
if not isinstance(node, MeasureAny):
16+
return RewriteResult()
17+
18+
if node.input.type.is_subseteq(QubitType):
19+
node.replace_by(
20+
MeasureQubit(
21+
qubit=node.input,
22+
)
23+
)
24+
return RewriteResult(has_done_something=True)
25+
elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]):
26+
node.replace_by(
27+
MeasureQubitList(
28+
qubits=node.input,
29+
)
30+
)
31+
return RewriteResult(has_done_something=True)
32+
33+
return RewriteResult()

test/squin/test_measure_sugar.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from kirin import ir
2+
from kirin.dialects import func
3+
4+
from bloqade import squin
5+
6+
7+
def get_return_value_stmt(kernel: ir.Method):
8+
assert isinstance(
9+
last_stmt := kernel.callable_region.blocks[-1].last_stmt, func.Return
10+
)
11+
return last_stmt.value.owner
12+
13+
14+
def test_measure_register():
15+
@squin.kernel
16+
def test_measure_sugar():
17+
q = squin.qubit.new(2)
18+
19+
return squin.qubit.measure(q)
20+
21+
assert isinstance(
22+
get_return_value_stmt(test_measure_sugar), squin.qubit.MeasureQubitList
23+
)
24+
25+
26+
def test_measure_qubit():
27+
@squin.kernel
28+
def test_measure_sugar():
29+
q = squin.qubit.new(2)
30+
31+
return squin.qubit.measure(q[0])
32+
33+
assert isinstance(
34+
get_return_value_stmt(test_measure_sugar),
35+
squin.qubit.MeasureQubit,
36+
)

0 commit comments

Comments
 (0)