Skip to content

Commit 09f47f8

Browse files
committed
fixing noise pass to work with new kirin patch
1 parent 9b87004 commit 09f47f8

File tree

2 files changed

+46
-29
lines changed

2 files changed

+46
-29
lines changed

src/bloqade/qasm2/passes/noise.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import field, dataclass
22

33
from kirin import ir
4-
from kirin.passes import Pass
4+
from kirin.passes import Pass, HintConst
55
from kirin.rewrite import (
66
Walk,
77
Chain,
@@ -10,11 +10,10 @@
1010
DeadCodeElimination,
1111
CommonSubexpressionElimination,
1212
)
13-
from kirin.rewrite.abc import RewriteResult
1413

1514
from bloqade.noise import native
1615
from bloqade.analysis import address
17-
from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule
16+
from bloqade.qasm2.rewrite.heuristic_noise import InsertGetQubit, NoiseRewriteRule
1817

1918

2019
@dataclass
@@ -38,16 +37,18 @@ def __post_init__(self):
3837
self.address_analysis = address.AddressAnalysis(self.dialects)
3938

4039
def unsafe_run(self, mt: ir.Method):
41-
result = RewriteResult()
42-
43-
frame, res = self.address_analysis.run_analysis(mt, no_raise=False)
40+
result = Walk(InsertGetQubit()).rewrite(mt.code)
41+
mt.print()
42+
HintConst(self.dialects).unsafe_run(mt)
43+
frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
4444
result = (
4545
Walk(
4646
NoiseRewriteRule(
4747
address_analysis=frame.entries,
4848
noise_model=self.noise_model,
4949
gate_noise_params=self.gate_noise_params,
50-
)
50+
),
51+
reverse=True,
5152
)
5253
.rewrite(mt.code)
5354
.join(result)

src/bloqade/qasm2/rewrite/heuristic_noise.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,34 @@
1010
from bloqade.qasm2.dialects import uop, core, glob, parallel
1111

1212

13+
class InsertGetQubit(rewrite_abc.RewriteRule):
14+
15+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
16+
if (
17+
not isinstance(node, core.QRegNew)
18+
or not isinstance(n_qubits_stmt := node.n_qubits.owner, py.Constant)
19+
or not isinstance(n_qubits := n_qubits_stmt.value.unwrap(), int)
20+
or (block := node.parent_block) is None
21+
):
22+
return rewrite_abc.RewriteResult()
23+
24+
n_qubits_stmt.detach()
25+
node.detach()
26+
if block.first_stmt is None:
27+
block.stmts.append(node)
28+
else:
29+
node.insert_before(block.first_stmt)
30+
n_qubits_stmt.insert_before(block.first_stmt)
31+
32+
for idx_val in range(n_qubits):
33+
idx = py.constant.Constant(value=idx_val)
34+
qubit = core.QRegGet(node.result, idx=idx.result)
35+
qubit.insert_after(node)
36+
idx.insert_after(node)
37+
38+
return rewrite_abc.RewriteResult(has_done_something=True)
39+
40+
1341
@dataclass
1442
class NoiseRewriteRule(rewrite_abc.RewriteRule):
1543
"""
@@ -24,12 +52,18 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
2452
noise_model: native.MoveNoiseModelABC = field(
2553
default_factory=native.TwoRowZoneModel
2654
)
27-
qubit_ssa_value: Dict[int, ir.SSAValue] = field(default_factory=dict, init=False)
55+
56+
def __post_init__(self):
57+
self.qubit_ssa_value: Dict[int, ir.SSAValue] = {}
58+
for ssa_value, addr in self.address_analysis.items():
59+
if (
60+
isinstance(addr, address.AddressQubit)
61+
and ssa_value not in self.qubit_ssa_value
62+
):
63+
self.qubit_ssa_value[addr.data] = ssa_value
2864

2965
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
30-
if isinstance(node, core.QRegNew):
31-
return self.rewrite_qreg_new(node)
32-
elif isinstance(node, uop.SingleQubitGate):
66+
if isinstance(node, uop.SingleQubitGate):
3367
return self.rewrite_single_qubit_gate(node)
3468
elif isinstance(node, uop.CZ):
3569
return self.rewrite_cz_gate(node)
@@ -42,24 +76,6 @@ def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
4276
else:
4377
return rewrite_abc.RewriteResult()
4478

45-
def rewrite_qreg_new(self, node: core.QRegNew):
46-
47-
addr = self.address_analysis[node.result]
48-
if not isinstance(addr, address.AddressReg):
49-
return rewrite_abc.RewriteResult()
50-
51-
has_done_something = False
52-
for idx_val, qid in enumerate(addr.data):
53-
if qid not in self.qubit_ssa_value:
54-
has_done_something = True
55-
idx = py.constant.Constant(value=idx_val)
56-
qubit = core.QRegGet(node.result, idx=idx.result)
57-
self.qubit_ssa_value[qid] = qubit.result
58-
qubit.insert_after(node)
59-
idx.insert_after(node)
60-
61-
return rewrite_abc.RewriteResult(has_done_something=has_done_something)
62-
6379
def insert_single_qubit_noise(
6480
self,
6581
node: ir.Statement,

0 commit comments

Comments
 (0)