Skip to content

Commit 009b73d

Browse files
authored
Fixing bugs related to kirin 0.17.3 version. (#214)
fixes: * ParallelToUop I just added a `setdefault` so that it would register the dictionary value and then overwrite it later * For RewriteNoiseModel: I split up the rewrites to be correct when having the new traversal order.
1 parent 4dbe384 commit 009b73d

File tree

3 files changed

+46
-30
lines changed

3 files changed

+46
-30
lines changed

src/bloqade/qasm2/passes/noise.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import field, dataclass
22

33
from kirin import ir
4-
from kirin.passes import Pass
4+
from kirin.passes import Pass, HintConst
55
from kirin.rewrite import (
66
Walk,
77
Chain,
@@ -10,11 +10,10 @@
1010
DeadCodeElimination,
1111
CommonSubexpressionElimination,
1212
)
13-
from kirin.rewrite.abc import RewriteResult
1413

1514
from bloqade.noise import native
1615
from bloqade.analysis import address
17-
from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule
16+
from bloqade.qasm2.rewrite.heuristic_noise import InsertGetQubit, NoiseRewriteRule
1817

1918

2019
@dataclass
@@ -38,16 +37,17 @@ def __post_init__(self):
3837
self.address_analysis = address.AddressAnalysis(self.dialects)
3938

4039
def unsafe_run(self, mt: ir.Method):
41-
result = RewriteResult()
42-
43-
frame, res = self.address_analysis.run_analysis(mt, no_raise=False)
40+
result = Walk(InsertGetQubit()).rewrite(mt.code)
41+
HintConst(self.dialects).unsafe_run(mt)
42+
frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise)
4443
result = (
4544
Walk(
4645
NoiseRewriteRule(
4746
address_analysis=frame.entries,
4847
noise_model=self.noise_model,
4948
gate_noise_params=self.gate_noise_params,
50-
)
49+
),
50+
reverse=True,
5151
)
5252
.rewrite(mt.code)
5353
.join(result)

src/bloqade/qasm2/rewrite/heuristic_noise.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,34 @@
1010
from bloqade.qasm2.dialects import uop, core, glob, parallel
1111

1212

13+
class InsertGetQubit(rewrite_abc.RewriteRule):
14+
15+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
16+
if (
17+
not isinstance(node, core.QRegNew)
18+
or not isinstance(n_qubits_stmt := node.n_qubits.owner, py.Constant)
19+
or not isinstance(n_qubits := n_qubits_stmt.value.unwrap(), int)
20+
or (block := node.parent_block) is None
21+
):
22+
return rewrite_abc.RewriteResult()
23+
24+
n_qubits_stmt.detach()
25+
node.detach()
26+
if block.first_stmt is None:
27+
block.stmts.append(node)
28+
else:
29+
node.insert_before(block.first_stmt)
30+
n_qubits_stmt.insert_before(block.first_stmt)
31+
32+
for idx_val in range(n_qubits):
33+
idx = py.constant.Constant(value=idx_val)
34+
qubit = core.QRegGet(node.result, idx=idx.result)
35+
qubit.insert_after(node)
36+
idx.insert_after(node)
37+
38+
return rewrite_abc.RewriteResult(has_done_something=True)
39+
40+
1341
@dataclass
1442
class NoiseRewriteRule(rewrite_abc.RewriteRule):
1543
"""
@@ -24,12 +52,18 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule):
2452
noise_model: native.MoveNoiseModelABC = field(
2553
default_factory=native.TwoRowZoneModel
2654
)
27-
qubit_ssa_value: Dict[int, ir.SSAValue] = field(default_factory=dict, init=False)
55+
56+
def __post_init__(self):
57+
self.qubit_ssa_value: Dict[int, ir.SSAValue] = {}
58+
for ssa_value, addr in self.address_analysis.items():
59+
if (
60+
isinstance(addr, address.AddressQubit)
61+
and ssa_value not in self.qubit_ssa_value
62+
):
63+
self.qubit_ssa_value[addr.data] = ssa_value
2864

2965
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
30-
if isinstance(node, core.QRegNew):
31-
return self.rewrite_qreg_new(node)
32-
elif isinstance(node, uop.SingleQubitGate):
66+
if isinstance(node, uop.SingleQubitGate):
3367
return self.rewrite_single_qubit_gate(node)
3468
elif isinstance(node, uop.CZ):
3569
return self.rewrite_cz_gate(node)
@@ -42,24 +76,6 @@ def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
4276
else:
4377
return rewrite_abc.RewriteResult()
4478

45-
def rewrite_qreg_new(self, node: core.QRegNew):
46-
47-
addr = self.address_analysis[node.result]
48-
if not isinstance(addr, address.AddressReg):
49-
return rewrite_abc.RewriteResult()
50-
51-
has_done_something = False
52-
for idx_val, qid in enumerate(addr.data):
53-
if qid not in self.qubit_ssa_value:
54-
has_done_something = True
55-
idx = py.constant.Constant(value=idx_val)
56-
qubit = core.QRegGet(node.result, idx=idx.result)
57-
self.qubit_ssa_value[qid] = qubit.result
58-
qubit.insert_after(node)
59-
idx.insert_after(node)
60-
61-
return rewrite_abc.RewriteResult(has_done_something=has_done_something)
62-
6379
def insert_single_qubit_noise(
6480
self,
6581
node: ir.Statement,

src/bloqade/qasm2/rewrite/uop_to_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def __call__(self, node: ir.Statement) -> RewriteResult:
154154
self.group_has_merged[group_number] = result.has_done_something
155155
return result
156156

157-
if self.group_has_merged[group_number]:
157+
if self.group_has_merged.setdefault(group_number, False):
158158
node.delete()
159159

160160
return RewriteResult(has_done_something=self.group_has_merged[group_number])

0 commit comments

Comments
 (0)