Skip to content

Commit 17db0d4

Browse files
authored
Merge branch 'main' into john/add-broadcast-to-wire
2 parents b2f1238 + 5a892c4 commit 17db0d4

File tree

12 files changed

+176
-52
lines changed

12 files changed

+176
-52
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
steps:
2222
- uses: actions/checkout@v4
2323
- name: Install uv
24-
uses: astral-sh/setup-uv@v5
24+
uses: astral-sh/setup-uv@v6
2525
with:
2626
# Install a specific version of uv.
2727
version: "0.5.1"

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
steps:
1111
- uses: actions/checkout@v4
1212
- name: Install uv
13-
uses: astral-sh/setup-uv@v5
13+
uses: astral-sh/setup-uv@v6
1414
with:
1515
# Install a specific version of uv.
1616
version: "0.5.5"

src/bloqade/pyqrack/qasm2/uop.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from kirin import interp
44

5+
from pyqrack.pauli import Pauli
56
from bloqade.pyqrack.reg import PyQrackQubit
67
from bloqade.qasm2.dialects import uop
78

@@ -26,7 +27,14 @@ class PyQrackMethods(interp.MethodTable):
2627
"tdg": "adjt",
2728
}
2829

29-
AXIS_MAP = {"rx": 1, "ry": 2, "rz": 3, "crx": 1, "cry": 2, "crz": 3}
30+
AXIS_MAP = {
31+
"rx": Pauli.PauliX,
32+
"ry": Pauli.PauliY,
33+
"rz": Pauli.PauliZ,
34+
"crx": Pauli.PauliX,
35+
"cry": Pauli.PauliY,
36+
"crz": Pauli.PauliZ,
37+
}
3038

3139
@interp.impl(uop.Barrier)
3240
def barrier(

src/bloqade/qasm2/passes/noise.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import field, dataclass
22

33
from kirin import ir
4-
from kirin.passes import Pass
4+
from kirin.passes import Pass, HintConst
55
from kirin.rewrite import (
66
Walk,
77
Chain,
@@ -10,11 +10,10 @@
1010
DeadCodeElimination,
1111
CommonSubexpressionElimination,
1212
)
13-
from kirin.rewrite.abc import RewriteResult
1413

1514
from bloqade.noise import native
1615
from bloqade.analysis import address
17-
from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule
16+
from bloqade.qasm2.rewrite.heuristic_noise import InsertGetQubit, NoiseRewriteRule
1817

1918

2019
@dataclass
@@ -38,16 +37,17 @@ def __post_init__(self):
3837
self.address_analysis = address.AddressAnalysis(self.dialects)
3938

4039
def unsafe_run(self, mt: ir.Method):
41-
result = RewriteResult()
42-
43-
frame, res = self.address_analysis.run_analysis(mt, no_raise=False)
40+
result = Walk(InsertGetQubit()).rewrite(mt.code)
41+
HintConst(self.dialects).unsafe_run(mt)
42+
frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
4443
result = (
4544
Walk(
4645
NoiseRewriteRule(
4746
address_analysis=frame.entries,
4847
noise_model=self.noise_model,
4948
gate_noise_params=self.gate_noise_params,
50-
)
49+
),
50+
reverse=True,
5151
)
5252
.rewrite(mt.code)
5353
.join(result)

src/bloqade/qasm2/rewrite/heuristic_noise.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,34 @@
1010
from bloqade.qasm2.dialects import uop, core, glob, parallel
1111

1212

13+
class InsertGetQubit(rewrite_abc.RewriteRule):
14+
15+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
16+
if (
17+
not isinstance(node, core.QRegNew)
18+
or not isinstance(n_qubits_stmt := node.n_qubits.owner, py.Constant)
19+
or not isinstance(n_qubits := n_qubits_stmt.value.unwrap(), int)
20+
or (block := node.parent_block) is None
21+
):
22+
return rewrite_abc.RewriteResult()
23+
24+
n_qubits_stmt.detach()
25+
node.detach()
26+
if block.first_stmt is None:
27+
block.stmts.append(node)
28+
else:
29+
node.insert_before(block.first_stmt)
30+
n_qubits_stmt.insert_before(block.first_stmt)
31+
32+
for idx_val in range(n_qubits):
33+
idx = py.constant.Constant(value=idx_val)
34+
qubit = core.QRegGet(node.result, idx=idx.result)
35+
qubit.insert_after(node)
36+
idx.insert_after(node)
37+
38+
return rewrite_abc.RewriteResult(has_done_something=True)
39+
40+
1341
@dataclass
1442
class NoiseRewriteRule(rewrite_abc.RewriteRule):
1543
"""
@@ -24,12 +52,18 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
2452
noise_model: native.MoveNoiseModelABC = field(
2553
default_factory=native.TwoRowZoneModel
2654
)
27-
qubit_ssa_value: Dict[int, ir.SSAValue] = field(default_factory=dict, init=False)
55+
56+
def __post_init__(self):
57+
self.qubit_ssa_value: Dict[int, ir.SSAValue] = {}
58+
for ssa_value, addr in self.address_analysis.items():
59+
if (
60+
isinstance(addr, address.AddressQubit)
61+
and ssa_value not in self.qubit_ssa_value
62+
):
63+
self.qubit_ssa_value[addr.data] = ssa_value
2864

2965
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
30-
if isinstance(node, core.QRegNew):
31-
return self.rewrite_qreg_new(node)
32-
elif isinstance(node, uop.SingleQubitGate):
66+
if isinstance(node, uop.SingleQubitGate):
3367
return self.rewrite_single_qubit_gate(node)
3468
elif isinstance(node, uop.CZ):
3569
return self.rewrite_cz_gate(node)
@@ -42,24 +76,6 @@ def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
4276
else:
4377
return rewrite_abc.RewriteResult()
4478

45-
def rewrite_qreg_new(self, node: core.QRegNew):
46-
47-
addr = self.address_analysis[node.result]
48-
if not isinstance(addr, address.AddressReg):
49-
return rewrite_abc.RewriteResult()
50-
51-
has_done_something = False
52-
for idx_val, qid in enumerate(addr.data):
53-
if qid not in self.qubit_ssa_value:
54-
has_done_something = True
55-
idx = py.constant.Constant(value=idx_val)
56-
qubit = core.QRegGet(node.result, idx=idx.result)
57-
self.qubit_ssa_value[qid] = qubit.result
58-
qubit.insert_after(node)
59-
idx.insert_after(node)
60-
61-
return rewrite_abc.RewriteResult(has_done_something=has_done_something)
62-
6379
def insert_single_qubit_noise(
6480
self,
6581
node: ir.Statement,

src/bloqade/qasm2/rewrite/uop_to_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def __call__(self, node: ir.Statement) -> RewriteResult:
154154
self.group_has_merged[group_number] = result.has_done_something
155155
return result
156156

157-
if self.group_has_merged[group_number]:
157+
if self.group_has_merged.setdefault(group_number, False):
158158
node.delete()
159159

160160
return RewriteResult(has_done_something=self.group_has_merged[group_number])

src/bloqade/squin/groups.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
from kirin import ir, passes
22
from kirin.prelude import structural_no_opt
33
from kirin.dialects import ilist
4+
from kirin.rewrite.walk import Walk
45

56
from . import op, wire, qubit
7+
from .rewrite.measure_desugar import MeasureDesugarRule
68

79

810
@ir.dialect_group(structural_no_opt.union([op, qubit]))
911
def kernel(self):
1012
fold_pass = passes.Fold(self)
1113
typeinfer_pass = passes.TypeInfer(self)
1214
ilist_desugar_pass = ilist.IListDesugar(self)
15+
measure_desugar_pass = Walk(MeasureDesugarRule())
1316

1417
def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
1518
method.verify()
@@ -18,6 +21,7 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
1821

1922
if typeinfer:
2023
typeinfer_pass(method)
24+
measure_desugar_pass.rewrite(method.code)
2125
ilist_desugar_pass(method)
2226
if typeinfer:
2327
typeinfer_pass(method) # fix types after desugaring

src/bloqade/squin/qubit.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
- `kirin.dialects.ilist`: provides the `ilist.IListType` type for lists of qubits.
88
"""
99

10-
from typing import Any
10+
from typing import Any, overload
1111

1212
from kirin import ir, types, lowering
1313
from kirin.decl import info, statement
@@ -41,8 +41,26 @@ class New(ir.Statement):
4141
result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any])
4242

4343

44+
class MeasureAny(ir.Statement):
45+
name = "measure"
46+
47+
traits = frozenset({lowering.FromPythonCall()})
48+
input: ir.SSAValue = info.argument(types.Any)
49+
result: ir.ResultValue = info.result(types.Any)
50+
51+
4452
@statement(dialect=dialect)
45-
class Measure(ir.Statement):
53+
class MeasureQubit(ir.Statement):
54+
name = "measure.qubit"
55+
56+
traits = frozenset({lowering.FromPythonCall()})
57+
qubit: ir.SSAValue = info.argument(ilist.IListType[QubitType])
58+
result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
59+
60+
61+
@statement(dialect=dialect)
62+
class MeasureQubitList(ir.Statement):
63+
name = "measure.qubit.list"
4664
traits = frozenset({lowering.FromPythonCall()})
4765
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
4866
result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
@@ -89,31 +107,39 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
89107
...
90108

91109

92-
@wraps(Broadcast)
93-
def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
94-
"""Broadcast and apply an operator to a list of qubits. For example, an operator
95-
that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0.
110+
@overload
111+
def measure(input: Qubit) -> bool: ...
112+
@overload
113+
def measure(input: ilist.IList[Qubit, Any] | list[Qubit]) -> list[bool]: ...
114+
115+
116+
@wraps(MeasureAny)
117+
def measure(input: Any) -> Any:
118+
"""Measure a qubit or qubits in the list.
96119
97120
Args:
98-
operator: The operator to broadcast and apply.
99-
qubits: The list of qubits to broadcast and apply the operator to. The size of the list
100-
must be inferable and match the number of qubits expected by the operator.
121+
input: A qubit or a list of qubits to measure.
101122
102123
Returns:
103-
None
124+
bool | list[bool]: The result of the measurement. If a single qubit is measured,
125+
a single boolean is returned. If a list of qubits is measured, a list of booleans
126+
is returned.
104127
"""
105128
...
106129

107130

108-
@wraps(Measure)
109-
def measure(qubits: ilist.IList[Qubit, Any]) -> int:
110-
"""Measure the qubits in the list."
131+
@wraps(Broadcast)
132+
def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
133+
"""Broadcast and apply an operator to a list of qubits. For example, an operator
134+
that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0.
111135
112136
Args:
113-
qubits: The list of qubits to measure.
137+
operator: The operator to broadcast and apply.
138+
qubits: The list of qubits to broadcast and apply the operator to. The size of the list
139+
must be inferable and match the number of qubits expected by the operator.
114140
115141
Returns:
116-
int: The result of the measurement.
142+
None
117143
"""
118144
...
119145

src/bloqade/squin/rewrite/__init__.py

Whitespace-only changes.
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from kirin import ir, types
2+
from kirin.dialects import ilist
3+
from kirin.rewrite.abc import RewriteRule, RewriteResult
4+
5+
from bloqade.squin.qubit import QubitType, MeasureAny, MeasureQubit, MeasureQubitList
6+
7+
8+
class MeasureDesugarRule(RewriteRule):
9+
"""
10+
Desugar measure operations in the circuit.
11+
"""
12+
13+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
14+
15+
if not isinstance(node, MeasureAny):
16+
return RewriteResult()
17+
18+
if node.input.type.is_subseteq(QubitType):
19+
node.replace_by(
20+
MeasureQubit(
21+
qubit=node.input,
22+
)
23+
)
24+
return RewriteResult(has_done_something=True)
25+
elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]):
26+
node.replace_by(
27+
MeasureQubitList(
28+
qubits=node.input,
29+
)
30+
)
31+
return RewriteResult(has_done_something=True)
32+
33+
return RewriteResult()

0 commit comments

Comments
 (0)