Skip to content

Commit 893a9e9

Browse files
weinbe58Roger-luo
andauthored
Fixing NoisePass with LiftQubits Pass. (#230)
in issue #226 The problem is because how the qubit -> global index. In this PR I add an explicit pass that lifts all qubit references to the top of the entry block. --------- Co-authored-by: Xiu-zhe (Roger) Luo <[email protected]>
1 parent 09f9214 commit 893a9e9

File tree

4 files changed

+67
-43
lines changed

4 files changed

+67
-43
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from kirin import ir
2+
from kirin.passes import Pass
3+
from kirin.rewrite import (
4+
Walk,
5+
Chain,
6+
Fixpoint,
7+
ConstantFold,
8+
CommonSubexpressionElimination,
9+
)
10+
from kirin.passes.hint_const import HintConst
11+
12+
from bloqade.qasm2.rewrite.insert_qubits import InsertGetQubit
13+
14+
15+
class LiftQubits(Pass):
16+
"""This pass lifts the creation of qubits to the block where the register is defined."""
17+
18+
def unsafe_run(self, mt: ir.Method):
19+
result = Walk(InsertGetQubit()).rewrite(mt.code)
20+
result = HintConst(self.dialects).unsafe_run(mt).join(result)
21+
result = (
22+
Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))
23+
.rewrite(mt.code)
24+
.join(result)
25+
)
26+
return result

src/bloqade/qasm2/passes/noise.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
from dataclasses import field, dataclass
22

33
from kirin import ir
4-
from kirin.passes import Pass, HintConst
4+
from kirin.passes import Pass
55
from kirin.rewrite import (
66
Walk,
7-
Chain,
87
Fixpoint,
9-
ConstantFold,
108
DeadCodeElimination,
11-
CommonSubexpressionElimination,
129
)
1310

1411
from bloqade.noise import native
1512
from bloqade.analysis import address
16-
from bloqade.qasm2.rewrite.heuristic_noise import InsertGetQubit, NoiseRewriteRule
13+
from bloqade.qasm2.passes.lift_qubits import LiftQubits
14+
from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule
1715

1816

1917
@dataclass
@@ -37,8 +35,7 @@ def __post_init__(self):
3735
self.address_analysis = address.AddressAnalysis(self.dialects)
3836

3937
def unsafe_run(self, mt: ir.Method):
40-
result = Walk(InsertGetQubit()).rewrite(mt.code)
41-
HintConst(self.dialects).unsafe_run(mt)
38+
result = LiftQubits(self.dialects).unsafe_run(mt)
4239
frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
4340
result = (
4441
Walk(
@@ -52,10 +49,5 @@ def unsafe_run(self, mt: ir.Method):
5249
.rewrite(mt.code)
5350
.join(result)
5451
)
55-
rule = Chain(
56-
ConstantFold(),
57-
DeadCodeElimination(),
58-
CommonSubexpressionElimination(),
59-
)
60-
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
52+
result = Fixpoint(Walk(DeadCodeElimination())).rewrite(mt.code).join(result)
6153
return result

src/bloqade/qasm2/rewrite/heuristic_noise.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,39 +3,11 @@
33

44
from kirin import ir
55
from kirin.rewrite import abc as rewrite_abc
6-
from kirin.dialects import py, ilist
6+
from kirin.dialects import ilist
77

88
from bloqade.noise import native
99
from bloqade.analysis import address
10-
from bloqade.qasm2.dialects import uop, core, glob, parallel
11-
12-
13-
class InsertGetQubit(rewrite_abc.RewriteRule):
14-
15-
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
16-
if (
17-
not isinstance(node, core.QRegNew)
18-
or not isinstance(n_qubits_stmt := node.n_qubits.owner, py.Constant)
19-
or not isinstance(n_qubits := n_qubits_stmt.value.unwrap(), int)
20-
or (block := node.parent_block) is None
21-
):
22-
return rewrite_abc.RewriteResult()
23-
24-
n_qubits_stmt.detach()
25-
node.detach()
26-
if block.first_stmt is None:
27-
block.stmts.append(node)
28-
else:
29-
node.insert_before(block.first_stmt)
30-
n_qubits_stmt.insert_before(block.first_stmt)
31-
32-
for idx_val in range(n_qubits):
33-
idx = py.constant.Constant(value=idx_val)
34-
qubit = core.QRegGet(node.result, idx=idx.result)
35-
qubit.insert_after(node)
36-
idx.insert_after(node)
37-
38-
return rewrite_abc.RewriteResult(has_done_something=True)
10+
from bloqade.qasm2.dialects import uop, glob, parallel
3911

4012

4113
@dataclass
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from kirin import ir
2+
from kirin.rewrite import abc as rewrite_abc
3+
from kirin.dialects import py
4+
5+
6+
class InsertGetQubit(rewrite_abc.RewriteRule):
7+
8+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
9+
from bloqade.qasm2 import core
10+
11+
if (
12+
not isinstance(node, core.QRegNew)
13+
or not isinstance(n_qubits_stmt := node.n_qubits.owner, py.Constant)
14+
or not isinstance(n_qubits := n_qubits_stmt.value.unwrap(), int)
15+
or (block := node.parent_block) is None
16+
):
17+
return rewrite_abc.RewriteResult()
18+
19+
n_qubits_stmt.detach()
20+
node.detach()
21+
if block.first_stmt is None:
22+
block.stmts.append(n_qubits_stmt)
23+
block.stmts.append(node)
24+
else:
25+
node.insert_before(block.first_stmt)
26+
n_qubits_stmt.insert_before(block.first_stmt)
27+
28+
for idx_val in range(n_qubits):
29+
idx = py.constant.Constant(value=idx_val)
30+
qubit = core.QRegGet(node.result, idx=idx.result)
31+
qubit.insert_after(node)
32+
idx.insert_after(node)
33+
34+
return rewrite_abc.RewriteResult(has_done_something=True)

0 commit comments

Comments
 (0)