From 9b87004b95960591b7de9c39cd548fe591a97c40 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 28 Apr 2025 11:48:26 -0400 Subject: [PATCH 1/4] fixing bug in parallel with reverse order --- src/bloqade/qasm2/rewrite/uop_to_parallel.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/bloqade/qasm2/rewrite/uop_to_parallel.py b/src/bloqade/qasm2/rewrite/uop_to_parallel.py index d7d1ee7e..76e87277 100644 --- a/src/bloqade/qasm2/rewrite/uop_to_parallel.py +++ b/src/bloqade/qasm2/rewrite/uop_to_parallel.py @@ -58,6 +58,10 @@ class SimpleMergePolicy(MergePolicyABC): group_has_merged: Dict[int, bool] = field(default_factory=dict) """Mapping from group number to whether the group has been merged""" + def __post_init__(self): + for group_number in range(len(self.merge_groups)): + self.group_has_merged[group_number] = False + @staticmethod def same_id_checker(ssa1: ir.SSAValue, ssa2: ir.SSAValue): if ssa1 is ssa2: @@ -154,7 +158,7 @@ def __call__(self, node: ir.Statement) -> RewriteResult: self.group_has_merged[group_number] = result.has_done_something return result - if self.group_has_merged[group_number]: + if self.group_has_merged.setdefault(group_number, False): node.delete() return RewriteResult(has_done_something=self.group_has_merged[group_number]) From 09f47f8ddecb80f539a93519df5432afb4a3acd8 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Tue, 29 Apr 2025 11:05:39 -0400 Subject: [PATCH 2/4] fixing noise pass to work with new kirin patch --- src/bloqade/qasm2/passes/noise.py | 15 ++--- src/bloqade/qasm2/rewrite/heuristic_noise.py | 60 +++++++++++++------- 2 files changed, 46 insertions(+), 29 deletions(-) diff --git a/src/bloqade/qasm2/passes/noise.py b/src/bloqade/qasm2/passes/noise.py index b91c1e51..cbaeac65 100644 --- a/src/bloqade/qasm2/passes/noise.py +++ b/src/bloqade/qasm2/passes/noise.py @@ -1,7 +1,7 @@ from dataclasses import field, dataclass from kirin import ir -from kirin.passes import Pass +from kirin.passes import Pass, HintConst from kirin.rewrite import ( Walk, Chain, @@ -10,11 +10,10 @@ DeadCodeElimination, CommonSubexpressionElimination, ) -from kirin.rewrite.abc import RewriteResult from bloqade.noise import native from bloqade.analysis import address -from bloqade.qasm2.rewrite.heuristic_noise import NoiseRewriteRule +from bloqade.qasm2.rewrite.heuristic_noise import InsertGetQubit, NoiseRewriteRule @dataclass @@ -38,16 +37,18 @@ def __post_init__(self): self.address_analysis = address.AddressAnalysis(self.dialects) def unsafe_run(self, mt: ir.Method): - result = RewriteResult() - - frame, res = self.address_analysis.run_analysis(mt, no_raise=False) + result = Walk(InsertGetQubit()).rewrite(mt.code) + mt.print() + HintConst(self.dialects).unsafe_run(mt) + frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise) result = ( Walk( NoiseRewriteRule( address_analysis=frame.entries, noise_model=self.noise_model, gate_noise_params=self.gate_noise_params, - ) + ), + reverse=True, ) .rewrite(mt.code) .join(result) diff --git a/src/bloqade/qasm2/rewrite/heuristic_noise.py b/src/bloqade/qasm2/rewrite/heuristic_noise.py index 7556079f..edb06fed 100644 --- a/src/bloqade/qasm2/rewrite/heuristic_noise.py +++ b/src/bloqade/qasm2/rewrite/heuristic_noise.py @@ -10,6 +10,34 @@ from bloqade.qasm2.dialects import uop, core, glob, parallel +class InsertGetQubit(rewrite_abc.RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult: + if ( + not isinstance(node, core.QRegNew) + or not isinstance(n_qubits_stmt := node.n_qubits.owner, py.Constant) + or not isinstance(n_qubits := n_qubits_stmt.value.unwrap(), int) + or (block := node.parent_block) is None + ): + return rewrite_abc.RewriteResult() + + n_qubits_stmt.detach() + node.detach() + if block.first_stmt is None: + block.stmts.append(node) + else: + node.insert_before(block.first_stmt) + n_qubits_stmt.insert_before(block.first_stmt) + + for idx_val in range(n_qubits): + idx = py.constant.Constant(value=idx_val) + qubit = core.QRegGet(node.result, idx=idx.result) + qubit.insert_after(node) + idx.insert_after(node) + + return rewrite_abc.RewriteResult(has_done_something=True) + + @dataclass class NoiseRewriteRule(rewrite_abc.RewriteRule): """ @@ -24,12 +52,18 @@ class NoiseRewriteRule(rewrite_abc.RewriteRule): noise_model: native.MoveNoiseModelABC = field( default_factory=native.TwoRowZoneModel ) - qubit_ssa_value: Dict[int, ir.SSAValue] = field(default_factory=dict, init=False) + + def __post_init__(self): + self.qubit_ssa_value: Dict[int, ir.SSAValue] = {} + for ssa_value, addr in self.address_analysis.items(): + if ( + isinstance(addr, address.AddressQubit) + and ssa_value not in self.qubit_ssa_value + ): + self.qubit_ssa_value[addr.data] = ssa_value def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult: - if isinstance(node, core.QRegNew): - return self.rewrite_qreg_new(node) - elif isinstance(node, uop.SingleQubitGate): + if isinstance(node, uop.SingleQubitGate): return self.rewrite_single_qubit_gate(node) elif isinstance(node, uop.CZ): return self.rewrite_cz_gate(node) @@ -42,24 +76,6 @@ def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult: else: return rewrite_abc.RewriteResult() - def rewrite_qreg_new(self, node: core.QRegNew): - - addr = self.address_analysis[node.result] - if not isinstance(addr, address.AddressReg): - return rewrite_abc.RewriteResult() - - has_done_something = False - for idx_val, qid in enumerate(addr.data): - if qid not in self.qubit_ssa_value: - has_done_something = True - idx = py.constant.Constant(value=idx_val) - qubit = core.QRegGet(node.result, idx=idx.result) - self.qubit_ssa_value[qid] = qubit.result - qubit.insert_after(node) - idx.insert_after(node) - - return rewrite_abc.RewriteResult(has_done_something=has_done_something) - def insert_single_qubit_noise( self, node: ir.Statement, From d43f43e5b2e240c6ced1e68c850e63d31e772a10 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Tue, 29 Apr 2025 11:07:21 -0400 Subject: [PATCH 3/4] removing print --- src/bloqade/qasm2/passes/noise.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bloqade/qasm2/passes/noise.py b/src/bloqade/qasm2/passes/noise.py index cbaeac65..477bd62a 100644 --- a/src/bloqade/qasm2/passes/noise.py +++ b/src/bloqade/qasm2/passes/noise.py @@ -38,7 +38,6 @@ def __post_init__(self): def unsafe_run(self, mt: ir.Method): result = Walk(InsertGetQubit()).rewrite(mt.code) - mt.print() HintConst(self.dialects).unsafe_run(mt) frame, _ = self.address_analysis.run_analysis(mt, no_raise=self.no_raise) result = ( From a5e5ceccad732deb2bf6751a5a9bcb14178a3784 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Tue, 29 Apr 2025 11:09:22 -0400 Subject: [PATCH 4/4] Removing uneeded fix --- src/bloqade/qasm2/rewrite/uop_to_parallel.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/bloqade/qasm2/rewrite/uop_to_parallel.py b/src/bloqade/qasm2/rewrite/uop_to_parallel.py index 76e87277..3f546121 100644 --- a/src/bloqade/qasm2/rewrite/uop_to_parallel.py +++ b/src/bloqade/qasm2/rewrite/uop_to_parallel.py @@ -58,10 +58,6 @@ class SimpleMergePolicy(MergePolicyABC): group_has_merged: Dict[int, bool] = field(default_factory=dict) """Mapping from group number to whether the group has been merged""" - def __post_init__(self): - for group_number in range(len(self.merge_groups)): - self.group_has_merged[group_number] = False - @staticmethod def same_id_checker(ssa1: ir.SSAValue, ssa2: ir.SSAValue): if ssa1 is ssa2: