diff --git a/src/bloqade/qasm2/passes/noise.py b/src/bloqade/qasm2/passes/noise.py index b91c1e51..477bd62a 100644 --- a/src/bloqade/qasm2/passes/noise.py +++ b/src/bloqade/qasm2/passes/noise.py @@ -1,7 +1,7 @@ from dataclasses import field, dataclass from kirin import ir -from kirin.passes import Pass +from kirin.passes import Pass, HintConst from kirin.rewrite import ( Walk, Chain, @@ -10,11 +10,10 @@ DeadCodeElimination, CommonSubexpressionElimination, ) -from kirin.rewrite.abc import RewriteResult from bloqade.noise import native from bloqade.analysis import address -from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule +from bloqade.qasm2.rewrite.heuristic_noise import InsertGetQubit, NoiseRewriteRule @dataclass @@ -38,16 +37,17 @@ def __post_init__(self): self.address_analysis = address.AddressAnalysis(self.dialects) def unsafe_run(self, mt: ir.Method): - result = RewriteResult() - - frame, res = self.address_analysis.run_analysis(mt, no_raise=False) + result = Walk(InsertGetQubit()).rewrite(mt.code) + HintConst(self.dialects).unsafe_run(mt) + frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise) result = ( Walk( NoiseRewriteRule( address_analysis=frame.entries, noise_model=self.noise_model, gate_noise_params=self.gate_noise_params, - ) + ), + reverse=True, ) .rewrite(mt.code) .join(result) diff --git a/src/bloqade/qasm2/rewrite/heuristic_noise.py b/src/bloqade/qasm2/rewrite/heuristic_noise.py index 7556079f..edb06fed 100644 --- a/src/bloqade/qasm2/rewrite/heuristic_noise.py +++ b/src/bloqade/qasm2/rewrite/heuristic_noise.py @@ -10,6 +10,34 @@ from bloqade.qasm2.dialects import uop, core, glob, parallel +class InsertGetQubit(rewrite_abc.RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult: + if ( + not isinstance(node, core.QRegNew) + or not isinstance(n_qubits_stmt := node.n_qubits.owner, py.Constant) + or not isinstance(n_qubits := n_qubits_stmt.value.unwrap(), int) + or (block := node.parent_block) is None + ): + return rewrite_abc.RewriteResult() + + n_qubits_stmt.detach() + node.detach() + if block.first_stmt is None: + block.stmts.append(node) + else: + node.insert_before(block.first_stmt) + n_qubits_stmt.insert_before(block.first_stmt) + + for idx_val in range(n_qubits): + idx = py.constant.Constant(value=idx_val) + qubit = core.QRegGet(node.result, idx=idx.result) + qubit.insert_after(node) + idx.insert_after(node) + + return rewrite_abc.RewriteResult(has_done_something=True) + + @dataclass class NoiseRewriteRule(rewrite_abc.RewriteRule): """ @@ -24,12 +52,18 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule): noise_model: native.MoveNoiseModelABC = field( default_factory=native.TwoRowZoneModel ) - qubit_ssa_value: Dict[int, ir.SSAValue] = field(default_factory=dict, init=False) + + def __post_init__(self): + self.qubit_ssa_value: Dict[int, ir.SSAValue] = {} + for ssa_value, addr in self.address_analysis.items(): + if ( + isinstance(addr, address.AddressQubit) + and ssa_value not in self.qubit_ssa_value + ): + self.qubit_ssa_value[addr.data] = ssa_value def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult: - if isinstance(node, core.QRegNew): - return self.rewrite_qreg_new(node) - elif isinstance(node, uop.SingleQubitGate): + if isinstance(node, uop.SingleQubitGate): return self.rewrite_single_qubit_gate(node) elif isinstance(node, uop.CZ): return self.rewrite_cz_gate(node) @@ -42,24 +76,6 @@ def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult: else: return rewrite_abc.RewriteResult() - def rewrite_qreg_new(self, node: core.QRegNew): - - addr = self.address_analysis[node.result] - if not isinstance(addr, address.AddressReg): - return rewrite_abc.RewriteResult() - - has_done_something = False - for idx_val, qid in enumerate(addr.data): - if qid not in self.qubit_ssa_value: - has_done_something = True - idx = py.constant.Constant(value=idx_val) - qubit = core.QRegGet(node.result, idx=idx.result) - self.qubit_ssa_value[qid] = qubit.result - qubit.insert_after(node) - idx.insert_after(node) - - return rewrite_abc.RewriteResult(has_done_something=has_done_something) - def insert_single_qubit_noise( self, node: ir.Statement, diff --git a/src/bloqade/qasm2/rewrite/uop_to_parallel.py b/src/bloqade/qasm2/rewrite/uop_to_parallel.py index d7d1ee7e..3f546121 100644 --- a/src/bloqade/qasm2/rewrite/uop_to_parallel.py +++ b/src/bloqade/qasm2/rewrite/uop_to_parallel.py @@ -154,7 +154,7 @@ def __call__(self, node: ir.Statement) -> RewriteResult: self.group_has_merged[group_number] = result.has_done_something return result - if self.group_has_merged[group_number]: + if self.group_has_merged.setdefault(group_number, False): node.delete() return RewriteResult(has_done_something=self.group_has_merged[group_number])