From b71828243edd70c42763df3105a80a15c04ce186 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 31 Oct 2025 10:26:24 +0100 Subject: [PATCH 01/25] Update analysis to new kirin API --- src/bloqade/analysis/address/analysis.py | 32 ++++++++++++--------- src/bloqade/analysis/measure_id/analysis.py | 16 +++++------ src/bloqade/analysis/measure_id/impls.py | 9 +++--- test/analysis/measure_id/test_measure_id.py | 3 ++ 4 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/bloqade/analysis/address/analysis.py b/src/bloqade/analysis/address/analysis.py index 1f78cfdb..23c08538 100644 --- a/src/bloqade/analysis/address/analysis.py +++ b/src/bloqade/analysis/address/analysis.py @@ -15,7 +15,7 @@ class AddressAnalysis(Forward[Address]): This analysis pass can be used to track the global addresses of qubits and wires. """ - keys = ["qubit.address"] + keys = ("qubit.address",) _const_prop: const.Propagate lattice = Address next_address: int = field(init=False) @@ -45,7 +45,7 @@ def try_eval_const_prop( ) -> interp.StatementResult[Address]: _frame = self._const_prop.initialize_frame(frame.code) _frame.set_values(stmt.args, tuple(x.result for x in args)) - result = self._const_prop.eval_stmt(_frame, stmt) + result = self._const_prop.frame_eval(_frame, stmt) match result: case interp.ReturnValue(constant_ret): @@ -117,9 +117,12 @@ def run_lattice( ) return ret case ConstResult(const.Value(ir.Method() as method)): - _, ret = self.run_method( - method, - self.permute_values(method.arg_names, inputs, kwargs), + _, ret = self.call( + method.code, + *inputs, + # **kwargs, + # **{k: v for k, v in zip(kwargs, frame.get_values(stmt.kwargs))}, + # self.permute_values(method.arg_names, inputs, kwargs), ) return ret case _: @@ -137,14 +140,17 @@ def get_const_value(self, addr: Address, typ: Type[T]) -> T | None: return value - def eval_stmt_fallback(self, frame: ForwardFrame[Address], stmt: ir.Statement): - args = frame.get_values(stmt.args) + def eval_fallback(self, frame: ForwardFrame[Address], node: ir.Statement): + args = frame.get_values(node.args) if types.is_tuple_of(args, ConstResult): - return self.try_eval_const_prop(frame, stmt, args) + return self.try_eval_const_prop(frame, node, args) - return tuple(Address.from_type(result.type) for result in stmt.results) + return tuple(Address.from_type(result.type) for result in node.results) - def run_method(self, method: ir.Method, args: tuple[Address, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - self_mt = ConstResult(const.Value(method)) - return self.run_callable(method.code, (self_mt,) + args) + # def run(self, method: ir.Method, *args: Address, **kwargs): + # # NOTE: we do not support dynamic calls here, thus no need to propagate method object + # self_mt = ConstResult(const.Value(method)) + # return self.call(method.code, self_mt, *args, **kwargs) + + def method_self(self, method: ir.Method) -> Address: + return ConstResult(const.Value(method)) diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index f2d5e9f3..1ebf0387 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -22,20 +22,20 @@ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]): measure_count = 0 def initialize_frame( - self, code: ir.Statement, *, has_parent_access: bool = False + self, node: ir.Statement, *, has_parent_access: bool = False ) -> MeasureIDFrame: - return MeasureIDFrame(code, has_parent_access=has_parent_access) + return MeasureIDFrame(node, has_parent_access=has_parent_access) # Still default to bottom, # but let constants return the softer "NoMeasureId" type from impl - def eval_stmt_fallback( - self, frame: ForwardFrame[MeasureId], stmt: ir.Statement + def eval_fallback( + self, frame: ForwardFrame[MeasureId], node: ir.Statement ) -> tuple[MeasureId, ...]: - return tuple(NotMeasureId() for _ in stmt.results) + return tuple(NotMeasureId() for _ in node.results) - def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + # def run(self, method: ir.Method, *args: MeasureId, **kwargs): + # # NOTE: we do not support dynamic calls here, thus no need to propagate method object + # return self.call(method.code, self.lattice.bottom(), *args, **kwargs) # Xiu-zhe (Roger) Luo came up with this in the address analysis, # reused here for convenience (now modified to be a bit more graceful) diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 993b97bd..439ae2a6 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -138,11 +138,10 @@ def return_(self, _: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Retu def invoke( self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke ): - _, ret = interp_.run_method( - stmt.callee, - interp_.permute_values( - stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs - ), + _, ret = interp_.call( + stmt.callee.code, + interp_.method_self(stmt.callee), + *frame.get_values(stmt.inputs), ) return (ret,) diff --git a/test/analysis/measure_id/test_measure_id.py b/test/analysis/measure_id/test_measure_id.py index 2a4f97c8..fc124fb2 100644 --- a/test/analysis/measure_id/test_measure_id.py +++ b/test/analysis/measure_id/test_measure_id.py @@ -278,3 +278,6 @@ def test(): assert [frame.entries[result] for result in results_at(test, 0, 6)] == [ InvalidMeasureId() ] + + +test_slice() From 2fed75d9be497e49274e1fc75227698c0fb36851 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 31 Oct 2025 16:09:59 +0100 Subject: [PATCH 02/25] Fix fidelity analysis --- src/bloqade/analysis/fidelity/analysis.py | 30 +++++++------------- src/bloqade/qasm2/dialects/noise/fidelity.py | 4 +-- test/analysis/fidelity/test_fidelity.py | 16 +++++------ test/analysis/measure_id/test_measure_id.py | 3 -- 4 files changed, 20 insertions(+), 33 deletions(-) diff --git a/src/bloqade/analysis/fidelity/analysis.py b/src/bloqade/analysis/fidelity/analysis.py index f1ad252f..503cc8e9 100644 --- a/src/bloqade/analysis/fidelity/analysis.py +++ b/src/bloqade/analysis/fidelity/analysis.py @@ -4,7 +4,6 @@ from kirin import ir from kirin.lattice import EmptyLattice from kirin.analysis import Forward -from kirin.interp.value import Successor from kirin.analysis.forward import ForwardFrame from ..address import Address, AddressAnalysis @@ -48,15 +47,11 @@ def main(): The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered. """ - _current_gate_fidelity: float = field(init=False) - atom_survival_probability: list[float] = field(init=False) """ The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register. """ - _current_atom_survival_probability: list[float] = field(init=False) - addr_frame: ForwardFrame[Address] = field(init=False) def initialize(self): @@ -67,27 +62,24 @@ def initialize(self): ] return self - def posthook_succ(self, frame: ForwardFrame, succ: Successor): - self.gate_fidelity *= self._current_gate_fidelity - for i, _current_survival in enumerate(self._current_atom_survival_probability): - self.atom_survival_probability[i] *= _current_survival - - def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): + def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): # NOTE: default is to conserve fidelity, so do nothing here return - def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]): - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]: + self._run_address_analysis(method, no_raise=False) + return super().run(method, *args, **kwargs) - def run_analysis( - self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True - ) -> tuple[ForwardFrame, Any]: - self._run_address_analysis(method, no_raise=no_raise) - return super().run(method) + def run_no_raise(self, method: ir.Method, *args: Any, **kwargs: Any): + self._run_address_analysis(method, no_raise=True) + return super().run_no_raise(method, *args, **kwargs) def _run_address_analysis(self, method: ir.Method, no_raise: bool): addr_analysis = AddressAnalysis(self.dialects) - addr_frame, _ = addr_analysis.run(method=method) + if no_raise: + addr_frame, _ = addr_analysis.run_no_raise(method=method) + else: + addr_frame, _ = addr_analysis.run(method=method) self.addr_frame = addr_frame # NOTE: make sure we have as many probabilities as we have addresses diff --git a/src/bloqade/qasm2/dialects/noise/fidelity.py b/src/bloqade/qasm2/dialects/noise/fidelity.py index f7ed75c4..acd17ac9 100644 --- a/src/bloqade/qasm2/dialects/noise/fidelity.py +++ b/src/bloqade/qasm2/dialects/noise/fidelity.py @@ -31,7 +31,7 @@ def pauli_channel( # NOTE: fidelity is just the inverse probability of any noise to occur fid = (1 - p) * (1 - p_ctrl) - interp._current_gate_fidelity *= fid + interp.gate_fidelity *= fid @interp.impl(AtomLossChannel) def atom_loss( @@ -44,4 +44,4 @@ def atom_loss( addresses = interp.addr_frame.get(stmt.qargs) # NOTE: get the corresponding index and reduce survival probability accordingly for index in addresses.data: - interp._current_atom_survival_probability[index] *= 1 - stmt.prob + interp.atom_survival_probability[index] *= 1 - stmt.prob diff --git a/test/analysis/fidelity/test_fidelity.py b/test/analysis/fidelity/test_fidelity.py index cbb39845..78ca8f26 100644 --- a/test/analysis/fidelity/test_fidelity.py +++ b/test/analysis/fidelity/test_fidelity.py @@ -19,7 +19,7 @@ def main(): return q fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity == 1 assert fid_analysis.atom_survival_probability[0] == 1 - p_loss @@ -49,11 +49,10 @@ def main(): main.print() fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) expected_fidelity = (1 - 3 * p_ch) ** 2 - assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity assert math.isclose(fid_analysis.gate_fidelity, expected_fidelity) @@ -69,11 +68,10 @@ def main(): main.print() fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) expected_fidelity = 1 - 3 * p_ch - assert fid_analysis.gate_fidelity == fid_analysis._current_gate_fidelity assert math.isclose(fid_analysis.gate_fidelity, expected_fidelity) @@ -123,12 +121,12 @@ def main_if(): ) NoisePass(main.dialects, noise_model=model)(main) fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) model = NoiseTestModel() NoisePass(main_if.dialects, noise_model=model)(main_if) fid_if_analysis = FidelityAnalysis(main_if.dialects) - fid_if_analysis.run_analysis(main_if, no_raise=False) + fid_if_analysis.run(main_if) assert 0 < fid_if_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1 assert ( @@ -186,7 +184,7 @@ def main_for(): ) NoisePass(main.dialects, noise_model=model)(main) fid_analysis = FidelityAnalysis(main.dialects) - fid_analysis.run_analysis(main, no_raise=False) + fid_analysis.run(main) model = NoiseTestModel() NoisePass(main_for.dialects, noise_model=model)(main_for) @@ -194,7 +192,7 @@ def main_for(): main_for.print() fid_for_analysis = FidelityAnalysis(main_for.dialects) - fid_for_analysis.run_analysis(main_for, no_raise=False) + fid_for_analysis.run(main_for) assert 0 < fid_for_analysis.gate_fidelity == fid_analysis.gate_fidelity < 1 assert ( diff --git a/test/analysis/measure_id/test_measure_id.py b/test/analysis/measure_id/test_measure_id.py index fc124fb2..2a4f97c8 100644 --- a/test/analysis/measure_id/test_measure_id.py +++ b/test/analysis/measure_id/test_measure_id.py @@ -278,6 +278,3 @@ def test(): assert [frame.entries[result] for result in results_at(test, 0, 6)] == [ InvalidMeasureId() ] - - -test_slice() From 41feea8e283b23dfbf8dfc0376c970bfb3cdba75 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 31 Oct 2025 16:17:22 +0100 Subject: [PATCH 03/25] Fix address analysis --- src/bloqade/analysis/address/analysis.py | 1 + src/bloqade/analysis/address/impls.py | 41 ++++++++++---------- test/analysis/address/test_qubit_analysis.py | 25 ++++++------ 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/src/bloqade/analysis/address/analysis.py b/src/bloqade/analysis/address/analysis.py index 23c08538..973e2d90 100644 --- a/src/bloqade/analysis/address/analysis.py +++ b/src/bloqade/analysis/address/analysis.py @@ -119,6 +119,7 @@ def run_lattice( case ConstResult(const.Value(ir.Method() as method)): _, ret = self.call( method.code, + self.method_self(method), *inputs, # **kwargs, # **{k: v for k, v in zip(kwargs, frame.get_values(stmt.kwargs))}, diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index 1a89bb3e..9df7d56f 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -180,13 +180,10 @@ def invoke( frame: ForwardFrame[Address], stmt: func.Invoke, ): - - args = interp_.permute_values( - stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs - ) - _, ret = interp_.run_method( - stmt.callee, - args, + _, ret = interp_.call( + stmt.callee.code, + interp_.method_self(stmt.callee), + *frame.get_values(stmt.inputs), ) return (ret,) @@ -319,26 +316,28 @@ def ifelse( ): body = stmt.then_body if const_cond.data else stmt.else_body with interp_.new_frame(stmt, has_parent_access=True) as body_frame: - ret = interp_.run_ssacfg_region(body_frame, body, (address_cond,)) + ret = interp_.frame_call_region(body_frame, stmt, body, address_cond) # interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values()) return ret else: # run both branches with interp_.new_frame(stmt, has_parent_access=True) as then_frame: - then_results = interp_.run_ssacfg_region( - then_frame, stmt.then_body, (address_cond,) - ) - interp_.set_values( - frame, then_frame.entries.keys(), then_frame.entries.values() + then_results = interp_.frame_call_region( + then_frame, + stmt, + stmt.then_body, + address_cond, ) + frame.set_values(then_frame.entries.keys(), then_frame.entries.values()) with interp_.new_frame(stmt, has_parent_access=True) as else_frame: - else_results = interp_.run_ssacfg_region( - else_frame, stmt.else_body, (address_cond,) - ) - interp_.set_values( - frame, else_frame.entries.keys(), else_frame.entries.values() + else_results = interp_.frame_call_region( + else_frame, + stmt, + stmt.else_body, + address_cond, ) + frame.set_values(else_frame.entries.keys(), else_frame.entries.values()) # TODO: pick the non-return value if isinstance(then_results, interp.ReturnValue) and isinstance( else_results, interp.ReturnValue @@ -364,12 +363,12 @@ def for_loop( iter_type, iterable = interp_.unpack_iterable(frame.get(stmt.iterable)) if iter_type is None: - return interp_.eval_stmt_fallback(frame, stmt) + return interp_.eval_fallback(frame, stmt) for value in iterable: with interp_.new_frame(stmt, has_parent_access=True) as body_frame: - loop_vars = interp_.run_ssacfg_region( - body_frame, stmt.body, (value,) + loop_vars + loop_vars = interp_.frame_call_region( + body_frame, stmt, stmt.body, value, *loop_vars ) if loop_vars is None: diff --git a/test/analysis/address/test_qubit_analysis.py b/test/analysis/address/test_qubit_analysis.py index 1866c9a1..dddf825a 100644 --- a/test/analysis/address/test_qubit_analysis.py +++ b/test/analysis/address/test_qubit_analysis.py @@ -21,7 +21,7 @@ def test(): return (q1[1], q2) address_analysis = address.AddressAnalysis(test.dialects) - frame, _ = address_analysis.run_analysis(test, no_raise=False) + frame, _ = address_analysis.run(test) address_types = collect_address_types(frame, address.PartialTuple) test.print(analysis=frame.entries) @@ -116,7 +116,7 @@ def main(): return q address_analysis = address.AddressAnalysis(main.dialects) - address_analysis.run_analysis(main, no_raise=False) + address_analysis.run(main) def test_new_qubit(): @@ -125,7 +125,7 @@ def main(): return squin.qubit.new() address_analysis = address.AddressAnalysis(main.dialects) - _, result = address_analysis.run_analysis(main, no_raise=False) + _, result = address_analysis.run(main) assert result == address.AddressQubit(0) @@ -139,8 +139,9 @@ def main(n: int): return qreg address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis( - main, args=(address.ConstResult(const.Unknown()),), no_raise=False + frame, result = address_analysis.run( + main, + address.ConstResult(const.Unknown()), ) assert result == address.AddressReg(data=tuple(range(4))) @@ -155,7 +156,7 @@ def main(n: int): return qreg address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.AddressReg(data=tuple(range(4))) @@ -165,7 +166,7 @@ def main(n: int): return (0, 1) + (2, n) address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.PartialTuple( data=( @@ -183,7 +184,7 @@ def main(n: int): return (0, 1) + [2, n] # type: ignore address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.Bottom() @@ -194,7 +195,7 @@ def main(n: tuple[int, ...]): return (0, 1) + n address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.Unknown() @@ -207,7 +208,7 @@ def main(q: qubit.Qubit): return (0, q, 2, q)[1::2] address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) assert result == address.UnknownReg() @@ -219,7 +220,7 @@ def main(n: int): main.print() address_analysis = address.AddressAnalysis(main.dialects) - frame, result = address_analysis.run_analysis(main, no_raise=False) + frame, result = address_analysis.run(main) main.print(analysis=frame.entries) assert ( result == address.UnknownReg() @@ -260,7 +261,7 @@ def main(): func = main analysis = address.AddressAnalysis(squin.kernel) - _, ret = analysis.run_analysis(func, no_raise=False) + _, ret = analysis.run(func) assert ret == address.AddressReg(data=tuple(range(20))) assert analysis.qubit_count == 20 From 9c57ac1af8766e34f55bba4ff4d3fe48b48ea877 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 31 Oct 2025 16:21:16 +0100 Subject: [PATCH 04/25] Fix PyQrack interpreter --- src/bloqade/pyqrack/task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bloqade/pyqrack/task.py b/src/bloqade/pyqrack/task.py index 1502f430..fdfbf728 100644 --- a/src/bloqade/pyqrack/task.py +++ b/src/bloqade/pyqrack/task.py @@ -28,8 +28,8 @@ def run(self) -> RetType: RetType, self.pyqrack_interp.run( self.kernel, - args=self.args, - kwargs=self.kwargs, + *self.args, + **self.kwargs, ), ) From 8985457e1dc6554d9cb1bf9e786d51293d33e6ef Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 31 Oct 2025 16:25:24 +0100 Subject: [PATCH 05/25] Fix PyQrack return value --- src/bloqade/pyqrack/task.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/bloqade/pyqrack/task.py b/src/bloqade/pyqrack/task.py index fdfbf728..0acb6ef0 100644 --- a/src/bloqade/pyqrack/task.py +++ b/src/bloqade/pyqrack/task.py @@ -24,14 +24,12 @@ class PyQrackSimulatorTask(AbstractSimulatorTask[Param, RetType, MemoryType]): pyqrack_interp: PyQrackInterpreter[MemoryType] def run(self) -> RetType: - return cast( - RetType, - self.pyqrack_interp.run( - self.kernel, - *self.args, - **self.kwargs, - ), + _, ret = self.pyqrack_interp.run( + self.kernel, + *self.args, + **self.kwargs, ) + return cast(RetType, ret) @property def state(self) -> MemoryType: From e999074ba708223fc25c5572b02be044f40cba84 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 31 Oct 2025 16:33:31 +0100 Subject: [PATCH 06/25] Discard gate dialect AFTER rewrite to native is done --- src/bloqade/native/upstream/squin2native.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/bloqade/native/upstream/squin2native.py b/src/bloqade/native/upstream/squin2native.py index 2a9131f2..998d34b6 100644 --- a/src/bloqade/native/upstream/squin2native.py +++ b/src/bloqade/native/upstream/squin2native.py @@ -62,16 +62,18 @@ def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method: all_dialects = chain.from_iterable( ker.dialects.data for kers in old_callgraph.defs.values() for ker in kers ) - new_dialects = ( - mt.dialects.union(all_dialects).discard(gate_dialect).union(kernel) - ) + combined_dialects = mt.dialects.union(all_dialects).union(kernel) - out = mt.similar(new_dialects) - UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out) - CallGraphPass(new_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(out) - # verify all kernels in the callgraph + out = mt.similar(combined_dialects) + UpdateDialectsOnCallGraph(combined_dialects, no_raise=no_raise)(out) + CallGraphPass(combined_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)( + out + ) + # verify all kernels in the callgraph and discard gate dialect + out.dialects.discard(gate_dialect) new_callgraph = CallGraph(out) for ker in new_callgraph.edges.keys(): + ker.dialects.discard(gate_dialect) ker.verify() return out From d0ff2d758fdcd22f851ff953ec6f81584712b22d Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 31 Oct 2025 16:38:29 +0100 Subject: [PATCH 07/25] Fix PyQrack target --- src/bloqade/pyqrack/target.py | 3 ++- test/pyqrack/runtime/noise/qasm2/test_loss.py | 8 +++----- test/pyqrack/runtime/noise/qasm2/test_pauli.py | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/bloqade/pyqrack/target.py b/src/bloqade/pyqrack/target.py index e9f21233..54700933 100644 --- a/src/bloqade/pyqrack/target.py +++ b/src/bloqade/pyqrack/target.py @@ -87,7 +87,8 @@ def run( """ fold = Fold(mt.dialects) fold(mt) - return self._get_interp(mt).run(mt, args, kwargs) + _, ret = self._get_interp(mt).run(mt, *args, **kwargs) + return ret def multi_run( self, diff --git a/test/pyqrack/runtime/noise/qasm2/test_loss.py b/test/pyqrack/runtime/noise/qasm2/test_loss.py index 4c1d19a1..29214c09 100644 --- a/test/pyqrack/runtime/noise/qasm2/test_loss.py +++ b/test/pyqrack/runtime/noise/qasm2/test_loss.py @@ -1,12 +1,10 @@ -from typing import Literal from unittest.mock import Mock from kirin import ir -from kirin.dialects import ilist from bloqade import qasm2 from bloqade.qasm2 import noise -from bloqade.pyqrack import PyQrackQubit, PyQrackInterpreter, reg +from bloqade.pyqrack import PyQrackInterpreter, reg from bloqade.pyqrack.base import MockMemory @@ -34,9 +32,9 @@ def test_atom_loss(c: qasm2.CReg): input = reg.CRegister(1) memory = MockMemory() - result: ilist.IList[PyQrackQubit, Literal[2]] = PyQrackInterpreter( + _, result = PyQrackInterpreter( qasm2.extended, memory=memory, rng_state=rng_state - ).run(test_atom_loss, (input,)) + ).run(test_atom_loss, input) assert result[0].state is reg.QubitState.Lost assert result[1].state is reg.QubitState.Active diff --git a/test/pyqrack/runtime/noise/qasm2/test_pauli.py b/test/pyqrack/runtime/noise/qasm2/test_pauli.py index 04541d23..2f9f2bb9 100644 --- a/test/pyqrack/runtime/noise/qasm2/test_pauli.py +++ b/test/pyqrack/runtime/noise/qasm2/test_pauli.py @@ -11,7 +11,7 @@ def run_mock(program: ir.Method, rng_state: Mock | None = None): PyQrackInterpreter( program.dialects, memory=(memory := MockMemory()), rng_state=rng_state - ).run(program, ()) + ).run(program) assert isinstance(mock := memory.sim_reg, Mock) return mock From 6e81083f887f4790261f94ed253aba83ebf2727f Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 31 Oct 2025 16:39:28 +0100 Subject: [PATCH 08/25] Fix self arg in qasm2 loading --- src/bloqade/qasm2/_qasm_loading.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bloqade/qasm2/_qasm_loading.py b/src/bloqade/qasm2/_qasm_loading.py index 63ffcd5f..f425cd08 100644 --- a/src/bloqade/qasm2/_qasm_loading.py +++ b/src/bloqade/qasm2/_qasm_loading.py @@ -82,11 +82,11 @@ def loads( body=body, ) + self_arg = ir.BlockArgument(body.blocks[0], 0) # Self argument + body.blocks[0]._args = (self_arg,) + mt = ir.Method( - mod=None, - py_func=None, sym_name=kernel_name, - arg_names=[], dialects=qasm2_lowering.dialects, code=code, ) From a4dd59226d6ab140d2c583b71dfc47dcd412fa58 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 31 Oct 2025 17:09:31 +0100 Subject: [PATCH 09/25] All PyQrack tests work --- test/pyqrack/runtime/test_qrack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/pyqrack/runtime/test_qrack.py b/test/pyqrack/runtime/test_qrack.py index 9161cabf..e6174797 100644 --- a/test/pyqrack/runtime/test_qrack.py +++ b/test/pyqrack/runtime/test_qrack.py @@ -14,7 +14,7 @@ def run_mock(program: ir.Method, rng_state: Mock | None = None): PyQrackInterpreter( program.dialects, memory=(memory := MockMemory()), rng_state=rng_state - ).run(program, ()) + ).run(program) assert isinstance(mock := memory.sim_reg, Mock) return mock From 0aed338d88ac0bc10c2a1ac7879220905527579c Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 3 Nov 2025 09:57:03 +0100 Subject: [PATCH 10/25] Fix qasm2 parallel & global rewrites --- src/bloqade/qasm2/emit/target.py | 12 ++++++------ test/qasm2/emit/test_qasm2_emit.py | 5 ----- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/bloqade/qasm2/emit/target.py b/src/bloqade/qasm2/emit/target.py index a2548bba..784034a4 100644 --- a/src/bloqade/qasm2/emit/target.py +++ b/src/bloqade/qasm2/emit/target.py @@ -106,13 +106,13 @@ def emit(self, entry: ir.Method) -> ast.MainProgram: unroll_ifs=self.unroll_ifs, ).fixpoint(entry) - # if not self.allow_global: - # # rewrite global to parallel - # GlobalToParallel(dialects=entry.dialects)(entry) + if not self.allow_global: + # rewrite global to parallel + GlobalToParallel(dialects=entry.dialects)(entry) - # if not self.allow_parallel: - # # rewrite parallel to uop - # ParallelToUOp(dialects=entry.dialects)(entry) + if not self.allow_parallel: + # rewrite parallel to uop + ParallelToUOp(dialects=entry.dialects)(entry) Py2QASM(entry.dialects)(entry) target_main = EmitQASM2Main(self.main_target).initialize() diff --git a/test/qasm2/emit/test_qasm2_emit.py b/test/qasm2/emit/test_qasm2_emit.py index 00eaaec4..f7e776eb 100644 --- a/test/qasm2/emit/test_qasm2_emit.py +++ b/test/qasm2/emit/test_qasm2_emit.py @@ -54,7 +54,6 @@ def glob_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_global(): @qasm2.extended @@ -85,7 +84,6 @@ def glob_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_global_allow_para(): @qasm2.extended @@ -118,7 +116,6 @@ def glob_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para(): @qasm2.extended @@ -145,7 +142,6 @@ def para_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para_allow_para(): @qasm2.extended @@ -201,7 +197,6 @@ def para_u(): ) -@pytest.mark.xfail(reason="Global and Parallel refactor still ongoing") def test_para_allow_global(): @qasm2.extended From cfd1fedce690cf824724dcb7188b85fbaa577f9d Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 3 Nov 2025 11:18:54 +0100 Subject: [PATCH 11/25] Fix types in qbraid lowering --- src/bloqade/qasm2/dialects/expr/stmts.py | 14 +++++++------- src/bloqade/qbraid/lowering.py | 1 + src/bloqade/qbraid/schema.py | 4 ++-- test/qbraid/test_lowering.py | 21 ++++++++++++++------- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/bloqade/qasm2/dialects/expr/stmts.py b/src/bloqade/qasm2/dialects/expr/stmts.py index e2e130e6..fad08b33 100644 --- a/src/bloqade/qasm2/dialects/expr/stmts.py +++ b/src/bloqade/qasm2/dialects/expr/stmts.py @@ -87,7 +87,7 @@ def print_impl(self, printer: Printer) -> None: # QASM 2.0 arithmetic operations -PyNum = types.Union(types.Int, types.Float) +PyNum = types.TypeVar("PyNum", bound=types.Union(types.Int, types.Float)) @statement(dialect=dialect) @@ -110,7 +110,7 @@ class Sin(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the sine of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The sine of the number.""" @@ -122,7 +122,7 @@ class Cos(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the cosine of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The cosine of the number.""" @@ -134,7 +134,7 @@ class Tan(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the tangent of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The tangent of the number.""" @@ -146,7 +146,7 @@ class Exp(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the exponential of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The exponential of the number.""" @@ -158,7 +158,7 @@ class Log(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the natural log of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The natural log of the number.""" @@ -170,7 +170,7 @@ class Sqrt(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the square root of.""" - result: ir.ResultValue = info.result(PyNum) + result: ir.ResultValue = info.result(types.Float) """result (float): The square root of the number.""" diff --git a/src/bloqade/qbraid/lowering.py b/src/bloqade/qbraid/lowering.py index 15958071..276e62ff 100644 --- a/src/bloqade/qbraid/lowering.py +++ b/src/bloqade/qbraid/lowering.py @@ -320,5 +320,6 @@ def lower_full_turns(self, value: float) -> ir.SSAValue: self.block_list.append(const_pi) turns = self.lower_number(2 * value) mul = qasm2.expr.Mul(const_pi.result, turns) + mul.result.type = types.Float self.block_list.append(mul) return mul.result diff --git a/src/bloqade/qbraid/schema.py b/src/bloqade/qbraid/schema.py index 54eed1c2..450d4f5a 100644 --- a/src/bloqade/qbraid/schema.py +++ b/src/bloqade/qbraid/schema.py @@ -238,13 +238,13 @@ def decompiled_circuit(self) -> str: str: The decompiled circuit from hardware execution. """ - from bloqade.noise import native from bloqade.qasm2.emit import QASM2 from bloqade.qasm2.passes import glob, parallel + from bloqade.qasm2.rewrite.noise import remove_noise mt = self.lower_noise_model("method") - native.RemoveNoisePass(mt.dialects)(mt) + remove_noise.RemoveNoisePass(mt.dialects)(mt) parallel.ParallelToUOp(mt.dialects)(mt) glob.GlobalToUOP(mt.dialects)(mt) return QASM2(qelib1=True).emit_str(mt) diff --git a/test/qbraid/test_lowering.py b/test/qbraid/test_lowering.py index 323972a8..cabca44e 100644 --- a/test/qbraid/test_lowering.py +++ b/test/qbraid/test_lowering.py @@ -56,11 +56,8 @@ def run_assert(noise_model: schema.NoiseModel, expected_stmts: List[ir.Statement ) expected_mt = ir.Method( - mod=None, - py_func=None, dialects=lowering.qbraid_noise, sym_name="test", - arg_names=[], code=expected_func_stmt, ) @@ -242,7 +239,10 @@ def test_lowering_global_w(): (lam_num := as_float(2 * -(0.5 + phi_val))), (lam := qasm2.expr.Mul(pi_lam.result, lam_num.result)), parallel.UGate( - theta=theta.result, phi=phi.result, lam=lam.result, qargs=qargs.result + theta=ir.ResultValue(theta, 0, type=types.Float), + phi=ir.ResultValue(phi, 0, type=types.Float), + lam=ir.ResultValue(lam, 0, type=types.Float), + qargs=qargs.result, ), func.Return(creg.result), ] @@ -304,7 +304,10 @@ def test_lowering_local_w(): (lam_num := as_float(2 * -(0.5 + phi_val))), (lam := qasm2.expr.Mul(pi_lam.result, lam_num.result)), parallel.UGate( - qargs=qargs.result, theta=theta.result, phi=phi.result, lam=lam.result + qargs=qargs.result, + theta=ir.ResultValue(theta, 0, type=types.Float), + phi=ir.ResultValue(phi, 0, type=types.Float), + lam=ir.ResultValue(lam, 0, type=types.Float), ), func.Return(creg.result), ] @@ -348,7 +351,9 @@ def test_lowering_global_rz(): (theta_pi := qasm2.expr.ConstPI()), (theta_num := as_float(2 * phi_val)), (theta := qasm2.expr.Mul(theta_pi.result, theta_num.result)), - parallel.RZ(theta=theta.result, qargs=qargs.result), + parallel.RZ( + theta=ir.ResultValue(theta, 0, type=types.Float), qargs=qargs.result + ), func.Return(creg.result), ] @@ -401,7 +406,9 @@ def test_lowering_local_rz(): (theta_pi := qasm2.expr.ConstPI()), (theta_num := as_float(2 * phi_val)), (theta := qasm2.expr.Mul(theta_pi.result, theta_num.result)), - parallel.RZ(theta=theta.result, qargs=qargs.result), + parallel.RZ( + theta=ir.ResultValue(theta, 0, type=types.Float), qargs=qargs.result + ), func.Return(creg.result), ] From 3b2fc6e7319718610c59ce6c10976f4e775c55b4 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 3 Nov 2025 11:27:51 +0100 Subject: [PATCH 12/25] Unmark & fix a bunch of qasm2 tests --- src/bloqade/qasm2/parse/lowering.py | 1 - src/bloqade/qasm2/rewrite/uop_to_parallel.py | 2 +- src/bloqade/squin/analysis/schedule.py | 15 +++++++-------- test/qasm2/passes/test_global_to_parallel.py | 3 --- test/qasm2/passes/test_global_to_uop.py | 2 -- test/qasm2/passes/test_heuristic_noise.py | 2 -- test/qasm2/passes/test_parallel_to_global.py | 7 ------- test/qasm2/passes/test_parallel_to_uop.py | 2 -- test/qasm2/passes/test_uop_to_parallel.py | 5 ----- test/qasm2/test_count.py | 20 +++++++------------- test/qasm2/test_lowering.py | 7 ------- test/qasm2/test_native.py | 2 -- 12 files changed, 15 insertions(+), 53 deletions(-) diff --git a/src/bloqade/qasm2/parse/lowering.py b/src/bloqade/qasm2/parse/lowering.py index 765d1eb3..07d66d28 100644 --- a/src/bloqade/qasm2/parse/lowering.py +++ b/src/bloqade/qasm2/parse/lowering.py @@ -450,7 +450,6 @@ def visit_Instruction(self, state: lowering.State[ast.Node], node: ast.Instructi func.Invoke( callee=value, inputs=tuple(params + qargs), - kwargs=tuple(), ) ) diff --git a/src/bloqade/qasm2/rewrite/uop_to_parallel.py b/src/bloqade/qasm2/rewrite/uop_to_parallel.py index b1c102ed..5c85376b 100644 --- a/src/bloqade/qasm2/rewrite/uop_to_parallel.py +++ b/src/bloqade/qasm2/rewrite/uop_to_parallel.py @@ -66,7 +66,7 @@ def same_id_checker(ssa1: ir.SSAValue, ssa2: ir.SSAValue): assert isinstance(hint1, lattice.Result) and isinstance( hint2, lattice.Result ) - return hint1.is_equal(hint2) + return hint1.is_structurally_equal(hint2) else: return False diff --git a/src/bloqade/squin/analysis/schedule.py b/src/bloqade/squin/analysis/schedule.py index e99e219d..35487e08 100644 --- a/src/bloqade/squin/analysis/schedule.py +++ b/src/bloqade/squin/analysis/schedule.py @@ -185,18 +185,17 @@ def push_current_dag(self, block: ir.Block): self.stmt_dag = StmtDag() self.use_def = {} - def run_method(self, method: ir.Method, args: tuple[GateSchedule, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + def method_self(self, method: ir.Method) -> GateSchedule: + return self.lattice.bottom() - def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): - if stmt.has_trait(ir.IsTerminator): + def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): + if node.has_trait(ir.IsTerminator): assert ( - stmt.parent_block is not None + node.parent_block is not None ), "Terminator statement has no parent block" - self.push_current_dag(stmt.parent_block) + self.push_current_dag(node.parent_block) - return tuple(self.lattice.top() for _ in stmt.results) + return tuple(self.lattice.top() for _ in node.results) def _update_dag(self, stmt: ir.Statement, addr: address.Address): if isinstance(addr, address.AddressQubit): diff --git a/test/qasm2/passes/test_global_to_parallel.py b/test/qasm2/passes/test_global_to_parallel.py index b77a3a13..3b147636 100644 --- a/test/qasm2/passes/test_global_to_parallel.py +++ b/test/qasm2/passes/test_global_to_parallel.py @@ -1,6 +1,5 @@ from typing import List -import pytest from kirin import ir, types from kirin.rewrite import Walk, Fixpoint, CommonSubexpressionElimination from kirin.dialects import py, func, ilist @@ -18,7 +17,6 @@ def as_float(value: float): return py.constant.Constant(value=value) -@pytest.mark.xfail def test_global2para_rewrite(): @qasm2.extended @@ -79,7 +77,6 @@ def main(): assert_methods(expected_method, main) -@pytest.mark.xfail def test_global2para_rewrite2(): @qasm2.extended diff --git a/test/qasm2/passes/test_global_to_uop.py b/test/qasm2/passes/test_global_to_uop.py index 9be187d9..dafe6f2a 100644 --- a/test/qasm2/passes/test_global_to_uop.py +++ b/test/qasm2/passes/test_global_to_uop.py @@ -1,6 +1,5 @@ from typing import List -import pytest from kirin import ir, types from kirin.rewrite import Walk, Fixpoint, CommonSubexpressionElimination from kirin.dialects import py, func @@ -18,7 +17,6 @@ def as_float(value: float): return py.constant.Constant(value=value) -@pytest.mark.xfail def test_global_rewrite(): @qasm2.extended diff --git a/test/qasm2/passes/test_heuristic_noise.py b/test/qasm2/passes/test_heuristic_noise.py index 20a4c5b0..78879a5e 100644 --- a/test/qasm2/passes/test_heuristic_noise.py +++ b/test/qasm2/passes/test_heuristic_noise.py @@ -1,4 +1,3 @@ -import pytest from kirin import ir, types from kirin.dialects import func, ilist from kirin.dialects.py import constant @@ -256,7 +255,6 @@ def test_parallel_cz_gate_noise(): assert_nodes(block, expected_block) -@pytest.mark.xfail def test_global_noise(): @qasm2.extended diff --git a/test/qasm2/passes/test_parallel_to_global.py b/test/qasm2/passes/test_parallel_to_global.py index 93fbac7f..c72533b8 100644 --- a/test/qasm2/passes/test_parallel_to_global.py +++ b/test/qasm2/passes/test_parallel_to_global.py @@ -1,10 +1,7 @@ -import pytest - from bloqade import qasm2 from bloqade.qasm2.passes.parallel import ParallelToGlobal -@pytest.mark.xfail def test_basic_rewrite(): @qasm2.extended @@ -32,7 +29,6 @@ def main(): ) -@pytest.mark.xfail def test_if_rewrite(): @qasm2.extended def main(): @@ -67,7 +63,6 @@ def main(): ) -@pytest.mark.xfail def test_should_not_be_rewritten(): @qasm2.extended @@ -93,7 +88,6 @@ def main(): ) -@pytest.mark.xfail def test_multiple_registers(): @qasm2.extended def main(): @@ -126,7 +120,6 @@ def main(): ) -@pytest.mark.xfail def test_reverse_order(): @qasm2.extended def main(): diff --git a/test/qasm2/passes/test_parallel_to_uop.py b/test/qasm2/passes/test_parallel_to_uop.py index 7484542e..c3e2c59e 100644 --- a/test/qasm2/passes/test_parallel_to_uop.py +++ b/test/qasm2/passes/test_parallel_to_uop.py @@ -1,6 +1,5 @@ from typing import List -import pytest from kirin import ir, types from kirin.dialects import py, func @@ -17,7 +16,6 @@ def as_float(value: float): return py.constant.Constant(value=value) -@pytest.mark.xfail def test_cz_rewrite(): @qasm2.extended diff --git a/test/qasm2/passes/test_uop_to_parallel.py b/test/qasm2/passes/test_uop_to_parallel.py index 7e1f3bdc..29231d39 100644 --- a/test/qasm2/passes/test_uop_to_parallel.py +++ b/test/qasm2/passes/test_uop_to_parallel.py @@ -1,5 +1,3 @@ -import pytest - from bloqade import qasm2 from bloqade.qasm2 import glob from bloqade.analysis import address @@ -7,7 +5,6 @@ from bloqade.qasm2.rewrite import SimpleOptimalMergePolicy -@pytest.mark.xfail def test_one(): @qasm2.gate @@ -50,7 +47,6 @@ def test(): ) -@pytest.mark.xfail def test_two(): @qasm2.extended @@ -89,7 +85,6 @@ def test(): _, _ = address.AddressAnalysis(test.dialects).run(test) -@pytest.mark.xfail def test_three(): @qasm2.extended diff --git a/test/qasm2/test_count.py b/test/qasm2/test_count.py index df5cef31..32aed54b 100644 --- a/test/qasm2/test_count.py +++ b/test/qasm2/test_count.py @@ -1,4 +1,3 @@ -import pytest from kirin import passes from kirin.dialects import py, ilist @@ -16,7 +15,6 @@ fold = passes.Fold(qasm2.main.add(py.tuple).add(ilist)) -@pytest.mark.xfail def test_fixed_count(): @qasm2.main def fixed_count(): @@ -36,7 +34,6 @@ def fixed_count(): assert address.qubit_count == 7 -@pytest.mark.xfail def test_multiple_return_only_reg(): @qasm2.main.add(py.tuple) @@ -54,7 +51,6 @@ def tuple_count(): assert isinstance(ret.data[1], AddressReg) and ret.data[1].data == range(3, 7) -@pytest.mark.xfail def test_dynamic_address(): @qasm2.main def dynamic_address(): @@ -64,14 +60,16 @@ def dynamic_address(): qasm2.measure(ra[0], ca[0]) qasm2.measure(rb[1], ca[1]) if ca[0] == ca[1]: - return ra + ret = ra else: - return rb + ret = rb + + return ret # dynamic_address.code.print() dynamic_address.print() fold(dynamic_address) - frame, result = address.run_analysis(dynamic_address) + frame, result = address.run(dynamic_address) dynamic_address.print(analysis=frame.entries) assert isinstance(result, Unknown) @@ -92,7 +90,6 @@ def dynamic_address(): # assert isinstance(result, ConstResult) -@pytest.mark.xfail def test_multi_return(): @qasm2.main.add(py.tuple) def multi_return_cnt(): @@ -110,7 +107,6 @@ def multi_return_cnt(): assert isinstance(result.data[2], AddressReg) -@pytest.mark.xfail def test_list(): @qasm2.main.add(ilist) def list_count_analy(): @@ -120,12 +116,11 @@ def list_count_analy(): return f list_count_analy.code.print() - _, ret = address.run_analysis(list_count_analy) + _, ret = address.run(list_count_analy) assert ret == AddressReg(data=(0, 1, 3)) -@pytest.mark.xfail def test_tuple_qubits(): @qasm2.main.add(py.tuple) def list_count_analy2(): @@ -136,7 +131,7 @@ def list_count_analy2(): return f list_count_analy2.code.print() - _, ret = address.run_analysis(list_count_analy2) + _, ret = address.run(list_count_analy2) assert isinstance(ret, PartialTuple) assert isinstance(ret.data[0], AddressQubit) and ret.data[0].data == 0 assert isinstance(ret.data[1], AddressQubit) and ret.data[1].data == 1 @@ -166,7 +161,6 @@ def list_count_analy2(): # assert isinstance(result.data[5], AddressQubit) and result.data[5].data == 6 -@pytest.mark.xfail def test_alias(): @qasm2.main diff --git a/test/qasm2/test_lowering.py b/test/qasm2/test_lowering.py index 6eee1509..617d7154 100644 --- a/test/qasm2/test_lowering.py +++ b/test/qasm2/test_lowering.py @@ -3,7 +3,6 @@ import tempfile import textwrap -import pytest from kirin import ir, types from kirin.dialects import func @@ -26,14 +25,12 @@ ) -@pytest.mark.xfail def test_run_lowering(): ast = qasm2.parse.loads(lines) code = QASM2(qasm2.main).run(ast) code.print() -@pytest.mark.xfail def test_loadfile(): with tempfile.TemporaryDirectory() as tmp_dir: @@ -44,7 +41,6 @@ def test_loadfile(): qasm2.loadfile(file) -@pytest.mark.xfail def test_negative_lowering(): mwe = """ @@ -84,7 +80,6 @@ def test_negative_lowering(): assert entry.code.is_structurally_equal(code) -@pytest.mark.xfail def test_gate(): qasm2_prog = textwrap.dedent( """ @@ -113,7 +108,6 @@ def test_gate(): assert math.isclose(abs(ket[3]) ** 2, 0.5, abs_tol=1e-6) -@pytest.mark.xfail def test_gate_with_params(): qasm2_prog = textwrap.dedent( """ @@ -144,7 +138,6 @@ def test_gate_with_params(): assert math.isclose(abs(ket[3]) ** 2, 0.5, abs_tol=1e-6) -@pytest.mark.xfail def test_if_lowering(): qasm2_prog = textwrap.dedent( diff --git a/test/qasm2/test_native.py b/test/qasm2/test_native.py index 15e6ebcb..fdfaf64d 100644 --- a/test/qasm2/test_native.py +++ b/test/qasm2/test_native.py @@ -3,7 +3,6 @@ import cirq import numpy as np -import pytest import cirq.testing import cirq.contrib.qasm_import as qasm_import import cirq.circuits.qasm_output as qasm_output @@ -158,7 +157,6 @@ def kernel(): assert new_qasm2.count("\n") > prog.count("\n") -@pytest.mark.xfail def test_ccx_rewrite(): @qasm2.extended From 681cb8b827629a285137f8bbafd480f88f172428 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 3 Nov 2025 15:10:43 +0100 Subject: [PATCH 13/25] Fix get_qubit_id test --- test/squin/test_qubit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/squin/test_qubit.py b/test/squin/test_qubit.py index bc16f4cc..001b9250 100644 --- a/test/squin/test_qubit.py +++ b/test/squin/test_qubit.py @@ -19,7 +19,7 @@ def main(): main.print() assert main.return_type.is_subseteq(types.Int) - @squin.kernel + @squin.kernel(fold=False) def main2(): q = squin.qalloc(2) @@ -34,10 +34,10 @@ def main2(): if m1_id != 0: # do something that errors - squin.x(q[4]) + q[0] + 1 if m2_id != 1: - squin.x(q[4]) + q[0] + 1 return squin.broadcast.measure(q) From 5b6df40286849382ad31abc4bbefc48ea1ff4471 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 3 Nov 2025 15:35:03 +0100 Subject: [PATCH 14/25] Cleaner self arg in qasm loading --- src/bloqade/qasm2/_qasm_loading.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/bloqade/qasm2/_qasm_loading.py b/src/bloqade/qasm2/_qasm_loading.py index f425cd08..9832c4ee 100644 --- a/src/bloqade/qasm2/_qasm_loading.py +++ b/src/bloqade/qasm2/_qasm_loading.py @@ -4,6 +4,7 @@ from typing import Any from kirin import ir, lowering +from kirin.types import MethodType from kirin.dialects import func from . import parse @@ -82,8 +83,10 @@ def loads( body=body, ) - self_arg = ir.BlockArgument(body.blocks[0], 0) # Self argument - body.blocks[0]._args = (self_arg,) + # self_arg = ir.BlockArgument(body.blocks[0], 0) # Self argument + + body.blocks[0].args.append_from(MethodType, kernel_name + "_self") + # body.blocks[0]._args = (self_arg,) mt = ir.Method( sym_name=kernel_name, From 16bd2725935e757062066c1544324e82cbafc80c Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 3 Nov 2025 17:17:57 +0100 Subject: [PATCH 15/25] Fix stim tests by removing newlines --- test/stim/passes/stim_reference_programs/debug/debug.stim | 1 - .../passes/stim_reference_programs/qubit/pick_if_else.stim | 1 - test/stim/passes/stim_reference_programs/qubit/qubit.stim | 1 - .../stim_reference_programs/qubit/qubit_broadcast.stim | 1 - .../passes/stim_reference_programs/qubit/qubit_loss.stim | 1 - .../passes/stim_reference_programs/qubit/qubit_reset.stim | 1 - .../passes/stim_reference_programs/qubit/rep_code.stim | 1 - .../passes/stim_reference_programs/qubit/u3_gates.stim | 1 - .../stim_reference_programs/qubit/u3_to_clifford.stim | 1 - test/stim/passes/test_squin_noise_to_stim.py | 7 ++++--- 10 files changed, 4 insertions(+), 12 deletions(-) diff --git a/test/stim/passes/stim_reference_programs/debug/debug.stim b/test/stim/passes/stim_reference_programs/debug/debug.stim index 7479e928..d4d0923c 100644 --- a/test/stim/passes/stim_reference_programs/debug/debug.stim +++ b/test/stim/passes/stim_reference_programs/debug/debug.stim @@ -1,2 +1 @@ - # debug message diff --git a/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim b/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim index 05e9c279..f0285a3d 100644 --- a/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim +++ b/test/stim/passes/stim_reference_programs/qubit/pick_if_else.stim @@ -1,2 +1 @@ - H 1 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit.stim b/test/stim/passes/stim_reference_programs/qubit/qubit.stim index 17873714..51cc860c 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit.stim @@ -1,4 +1,3 @@ - H 0 1 X 0 CX 0 1 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim b/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim index 708cb2a0..e033406d 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit_broadcast.stim @@ -1,3 +1,2 @@ - H 0 1 2 3 MZ(0.00000000) 0 1 2 3 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim b/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim index 4dee3da2..62b6bf1b 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit_loss.stim @@ -1,4 +1,3 @@ - H 0 1 2 3 4 I_ERROR[loss](0.10000000) 3 I_ERROR[loss](0.05000000) 0 1 2 3 4 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim b/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim index 958e0cfe..bcecaca7 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit_reset.stim @@ -1,3 +1,2 @@ - RZ 0 MZ(0.00000000) 0 diff --git a/test/stim/passes/stim_reference_programs/qubit/rep_code.stim b/test/stim/passes/stim_reference_programs/qubit/rep_code.stim index 9105cf43..171dd872 100644 --- a/test/stim/passes/stim_reference_programs/qubit/rep_code.stim +++ b/test/stim/passes/stim_reference_programs/qubit/rep_code.stim @@ -1,4 +1,3 @@ - RZ 0 1 2 3 4 CX 0 1 2 3 CX 2 1 4 3 diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim index 4764563d..cfc100b0 100644 --- a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim +++ b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim @@ -1,4 +1,3 @@ - Z 0 SQRT_X_DAG 0 SQRT_X_DAG 0 diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim b/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim index 6eb044c6..17a51d33 100644 --- a/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim +++ b/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim @@ -1,3 +1,2 @@ - H 0 MZ(0.00000000) 0 diff --git a/test/stim/passes/test_squin_noise_to_stim.py b/test/stim/passes/test_squin_noise_to_stim.py index dc3346d5..b330464d 100644 --- a/test/stim/passes/test_squin_noise_to_stim.py +++ b/test/stim/passes/test_squin_noise_to_stim.py @@ -252,10 +252,11 @@ def test(): SquinToStimPass(test.dialects)(test) - emit = EmitStimMain(correlation_identifier_offset=10) + buf = io.StringIO() + emit = EmitStimMain(test.dialects, correlation_identifier_offset=10, io=buf) emit.initialize() - emit.run(mt=test, args=()) - stim_str = emit.get_output().strip() + emit.run(test) + stim_str = buf.getvalue().strip() assert stim_str == "I_ERROR[correlated_loss:10](0.10000000) 0 1 2 3" From 8b92c5c26478a23d421c73433a5e828d25d4f661 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 5 Nov 2025 10:13:42 +0100 Subject: [PATCH 16/25] Fix run_lattice in address analysis --- src/bloqade/analysis/address/analysis.py | 19 ++++++------------- src/bloqade/analysis/address/impls.py | 5 +++-- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/src/bloqade/analysis/address/analysis.py b/src/bloqade/analysis/address/analysis.py index 973e2d90..921f4b3b 100644 --- a/src/bloqade/analysis/address/analysis.py +++ b/src/bloqade/analysis/address/analysis.py @@ -96,7 +96,8 @@ def run_lattice( self, callee: Address, inputs: tuple[Address, ...], - kwargs: tuple[str, ...], + keys: tuple[str, ...], + kwargs: tuple[Address, ...], ) -> Address: """Run a callable lattice element with the given inputs and keyword arguments. @@ -111,19 +112,16 @@ def run_lattice( """ match callee: - case PartialLambda(code=code, argnames=argnames): - _, ret = self.run_callable( - code, (callee,) + self.permute_values(argnames, inputs, kwargs) + case PartialLambda(code=code): + _, ret = self.call( + code, callee, *inputs, **{k: v for k, v in zip(keys, kwargs)} ) - return ret case ConstResult(const.Value(ir.Method() as method)): _, ret = self.call( method.code, self.method_self(method), *inputs, - # **kwargs, - # **{k: v for k, v in zip(kwargs, frame.get_values(stmt.kwargs))}, - # self.permute_values(method.arg_names, inputs, kwargs), + **{k: v for k, v in zip(keys, kwargs)}, ) return ret case _: @@ -148,10 +146,5 @@ def eval_fallback(self, frame: ForwardFrame[Address], node: ir.Statement): return tuple(Address.from_type(result.type) for result in node.results) - # def run(self, method: ir.Method, *args: Address, **kwargs): - # # NOTE: we do not support dynamic calls here, thus no need to propagate method object - # self_mt = ConstResult(const.Value(method)) - # return self.call(method.code, self_mt, *args, **kwargs) - def method_self(self, method: ir.Method) -> Address: return ConstResult(const.Value(method)) diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index 9df7d56f..d7932986 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -97,7 +97,7 @@ def map_( results = [] for ele in values: - ret = interp_.run_lattice(fn, (ele,), ()) + ret = interp_.run_lattice(fn, (ele,), (), ()) results.append(ret) if isinstance(stmt, ilist.Map): @@ -216,7 +216,8 @@ def call( result = interp_.run_lattice( frame.get(stmt.callee), frame.get_values(stmt.inputs), - stmt.kwargs, + stmt.keys, + frame.get_values(stmt.kwargs), ) return (result,) From 141ff59a175b3191d9e4182afd15ca13e4de3120 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 5 Nov 2025 10:28:29 +0100 Subject: [PATCH 17/25] Remove old code --- src/bloqade/analysis/measure_id/analysis.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/bloqade/analysis/measure_id/analysis.py b/src/bloqade/analysis/measure_id/analysis.py index 1ebf0387..8b65b2f3 100644 --- a/src/bloqade/analysis/measure_id/analysis.py +++ b/src/bloqade/analysis/measure_id/analysis.py @@ -33,10 +33,6 @@ def eval_fallback( ) -> tuple[MeasureId, ...]: return tuple(NotMeasureId() for _ in node.results) - # def run(self, method: ir.Method, *args: MeasureId, **kwargs): - # # NOTE: we do not support dynamic calls here, thus no need to propagate method object - # return self.call(method.code, self.lattice.bottom(), *args, **kwargs) - # Xiu-zhe (Roger) Luo came up with this in the address analysis, # reused here for convenience (now modified to be a bit more graceful) # TODO: Remove this function once upgrade to kirin 0.18 happens, @@ -45,7 +41,7 @@ def eval_fallback( T = TypeVar("T") def get_const_value( - self, input_type: type[T], value: ir.SSAValue + self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue ) -> type[T] | None: if isinstance(hint := value.hints.get("const"), const.Value): data = hint.data From fc61d023ffe456556af7b14b728119d7ed898f79 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 5 Nov 2025 10:28:39 +0100 Subject: [PATCH 18/25] Clean up cirq emit --- src/bloqade/cirq_utils/emit/base.py | 44 +++++++---------------------- 1 file changed, 10 insertions(+), 34 deletions(-) diff --git a/src/bloqade/cirq_utils/emit/base.py b/src/bloqade/cirq_utils/emit/base.py index 199e49dc..81e9152a 100644 --- a/src/bloqade/cirq_utils/emit/base.py +++ b/src/bloqade/cirq_utils/emit/base.py @@ -189,39 +189,8 @@ def initialize_frame( node, has_parent_access=has_parent_access, qubits=self.qubits ) - def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]): - return self.call(method, *args) - - def run_callable_region( - self, - frame: EmitCirqFrame, - code: ir.Statement, - region: ir.Region, - args: tuple, - ): - if len(region.blocks) > 0: - block_args = list(region.blocks[0].args) - # NOTE: skip self arg - frame.set_values(block_args[1:], args) - - results = self.frame_eval(frame, code) - if isinstance(results, tuple): - if len(results) == 0: - return self.void - elif len(results) == 1: - return results[0] - raise interp.InterpreterError(f"Unexpected results {results}") - - def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit: - for stmt in block.stmts: - result = self.frame_eval(frame, stmt) - if isinstance(result, tuple): - frame.set_values(stmt.results, result) - - return self.circuit - def reset(self): - pass + self.circuit = cirq.Circuit() def eval_fallback(self, frame: EmitCirqFrame, node: ir.Statement) -> tuple: return tuple(None for _ in range(len(node.results))) @@ -261,8 +230,15 @@ def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem): return () @interp.impl(py.Constant) - def emit_constant(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: py.Constant): - return (stmt.value.data,) # pyright: ignore[reportAttributeAccessIssue] + def emit_constant( + self, emit: EmitCirq, frame: EmitCirqFrame, stmt: py.Constant[ir.PyAttr] + ): + if not isinstance(stmt.value, ir.PyAttr): + raise interp.exceptions.InterpreterError( + "Cannot lower constant without concrete data!" + ) + + return (stmt.value.data,) @ilist.dialect.register(key="emit.cirq") From 6b20ec025a6471fe760abd2227b720757f9dcf79 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 5 Nov 2025 10:35:33 +0100 Subject: [PATCH 19/25] Revert change in emit constant --- src/bloqade/cirq_utils/emit/base.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/bloqade/cirq_utils/emit/base.py b/src/bloqade/cirq_utils/emit/base.py index 81e9152a..4c831d81 100644 --- a/src/bloqade/cirq_utils/emit/base.py +++ b/src/bloqade/cirq_utils/emit/base.py @@ -230,15 +230,8 @@ def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem): return () @interp.impl(py.Constant) - def emit_constant( - self, emit: EmitCirq, frame: EmitCirqFrame, stmt: py.Constant[ir.PyAttr] - ): - if not isinstance(stmt.value, ir.PyAttr): - raise interp.exceptions.InterpreterError( - "Cannot lower constant without concrete data!" - ) - - return (stmt.value.data,) + def emit_constant(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: py.Constant): + return (stmt.value.data,) # pyright: ignore[reportAttributeAccessIssue] @ilist.dialect.register(key="emit.cirq") From 17e8ffb3b117a416f15c4f54705865e177843ba9 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 5 Nov 2025 10:35:46 +0100 Subject: [PATCH 20/25] Clean up cirq lowering a bit --- src/bloqade/cirq_utils/lowering.py | 4 +--- test/cirq_utils/test_cirq_to_squin.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 7d81b45b..fc6d895a 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -99,7 +99,7 @@ def main(): ``` """ - target = Squin(dialects=dialects, circuit=circuit) + target = Squin(dialects, circuit) body = target.run( circuit, source=str(circuit), # TODO: proper source string @@ -144,8 +144,6 @@ def main(): ) mt = ir.Method( - mod=None, - py_func=None, sym_name=kernel_name, arg_names=arg_names, dialects=dialects, diff --git a/test/cirq_utils/test_cirq_to_squin.py b/test/cirq_utils/test_cirq_to_squin.py index 101cd3be..479ba051 100644 --- a/test/cirq_utils/test_cirq_to_squin.py +++ b/test/cirq_utils/test_cirq_to_squin.py @@ -250,9 +250,9 @@ def test_nesting_lowered_circuit(): @squin.kernel def main(): qreg = get_entangled_qubits() - qreg2 = squin.squin.qalloc(1) + qreg2 = squin.qalloc(1) entangle_qubits([qreg[1], qreg2[0]]) - return squin.qubit.measure(qreg2) + return squin.broadcast.measure(qreg2) # if you get up to here, the validation works main.print() From 636083b931ce4c67658f9653db90e4e403749d9b Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 5 Nov 2025 10:48:33 +0100 Subject: [PATCH 21/25] Update kernel validation --- .../analysis/logical_validation/analysis.py | 6 +++--- src/bloqade/validation/analysis/analysis.py | 14 +++++++------- src/bloqade/validation/kernel_validation.py | 18 ++++++++++++++++-- test/gemini/test_logical_validation.py | 10 ++++------ 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/bloqade/gemini/analysis/logical_validation/analysis.py b/src/bloqade/gemini/analysis/logical_validation/analysis.py index 14a03cbf..4cf2d7a2 100644 --- a/src/bloqade/gemini/analysis/logical_validation/analysis.py +++ b/src/bloqade/gemini/analysis/logical_validation/analysis.py @@ -9,9 +9,9 @@ class GeminiLogicalValidationAnalysis(ValidationAnalysis): first_gate = True - def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): - if isinstance(stmt, squin.gate.stmts.Gate): + def eval_fallback(self, frame: ValidationFrame, node: ir.Statement): + if isinstance(node, squin.gate.stmts.Gate): # NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here self.first_gate = False - return super().eval_stmt_fallback(frame, stmt) + return super().eval_fallback(frame, node) diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py index 323cbd40..19220b73 100644 --- a/src/bloqade/validation/analysis/analysis.py +++ b/src/bloqade/validation/analysis/analysis.py @@ -28,14 +28,14 @@ class ValidationAnalysis(ForwardExtra[ValidationFrame, ErrorType], ABC): lattice = ErrorType - def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): - return self.run_callable(method.code, (self.lattice.top(),) + args) - - def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): + def eval_fallback(self, frame: ValidationFrame, node: ir.Statement): # NOTE: default to no errors - return tuple(self.lattice.top() for _ in stmt.results) + return tuple(self.lattice.top() for _ in node.results) def initialize_frame( - self, code: ir.Statement, *, has_parent_access: bool = False + self, node: ir.Statement, *, has_parent_access: bool = False ) -> ValidationFrame: - return ValidationFrame(code, has_parent_access=has_parent_access) + return ValidationFrame(node, has_parent_access=has_parent_access) + + def method_self(self, method: ir.Method) -> ErrorType: + return self.lattice.top() diff --git a/src/bloqade/validation/kernel_validation.py b/src/bloqade/validation/kernel_validation.py index 84159352..d3f82774 100644 --- a/src/bloqade/validation/kernel_validation.py +++ b/src/bloqade/validation/kernel_validation.py @@ -48,9 +48,23 @@ class KernelValidation: validation_analysis_cls: type[ValidationAnalysis] """The analysis that you want to run in order to validate the kernel.""" - def run(self, mt: ir.Method, **kwargs) -> None: + def run(self, mt: ir.Method, no_raise: bool = True) -> None: + """Run the kernel validation analysis and raise any errors found. + + Args: + mt (ir.Method): The method to validate + no_raise (bool): Whether or not to raise errors when running the analysis. + This is only to make sure the analysis works. Errors found during + the analysis will be raised regardless of this setting. Defaults to `True`. + + """ + validation_analysis = self.validation_analysis_cls(mt.dialects) - validation_frame, _ = validation_analysis.run_analysis(mt, **kwargs) + + if no_raise: + validation_frame, _ = validation_analysis.run_no_raise(mt) + else: + validation_frame, _ = validation_analysis.run(mt) errors = validation_frame.errors diff --git a/test/gemini/test_logical_validation.py b/test/gemini/test_logical_validation.py index ce8e1a34..046b2ac4 100644 --- a/test/gemini/test_logical_validation.py +++ b/test/gemini/test_logical_validation.py @@ -30,16 +30,14 @@ def main(): if m2: squin.y(q[2]) - frame, _ = GeminiLogicalValidationAnalysis(main.dialects).run_analysis( - main, no_raise=False - ) + frame, _ = GeminiLogicalValidationAnalysis(main.dialects).run_no_raise(main) main.print(analysis=frame.entries) validator = KernelValidation(GeminiLogicalValidationAnalysis) with pytest.raises(ValidationErrorGroup): - validator.run(main) + validator.run(main, no_raise=False) def test_for_loop(): @@ -104,8 +102,8 @@ def invalid(): squin.cx(q[0], q[1]) squin.u3(0.123, 0.253, 1.2, q[0]) - frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_analysis( - invalid, no_raise=False + frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise( + invalid ) invalid.print(analysis=frame.entries) From d86cad350a858b1ef6d56fa6e7f219a66c0465ff Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 5 Nov 2025 11:00:09 +0100 Subject: [PATCH 22/25] Remove unneeded method from QASM2 emit --- src/bloqade/qasm2/emit/base.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/bloqade/qasm2/emit/base.py b/src/bloqade/qasm2/emit/base.py index 4f7fba1d..4fac3217 100644 --- a/src/bloqade/qasm2/emit/base.py +++ b/src/bloqade/qasm2/emit/base.py @@ -45,11 +45,6 @@ def initialize_frame( ) -> EmitQASM2Frame[StmtType]: return EmitQASM2Frame(node, has_parent_access=has_parent_access) - def run_method( - self, method: ir.Method, args: tuple[ast.Node | None, ...] - ) -> tuple[EmitQASM2Frame[StmtType], ast.Node | None]: - return self.call(method, *args) - def emit_block(self, frame: EmitQASM2Frame, block: ir.Block) -> ast.Node | None: for stmt in block.stmts: result = self.frame_eval(frame, stmt) From ef1498782808679699228b859a7d2e8968559bf2 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 5 Nov 2025 11:05:07 +0100 Subject: [PATCH 23/25] Remove some commented code --- src/bloqade/qasm2/_qasm_loading.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/bloqade/qasm2/_qasm_loading.py b/src/bloqade/qasm2/_qasm_loading.py index 9832c4ee..57ee5815 100644 --- a/src/bloqade/qasm2/_qasm_loading.py +++ b/src/bloqade/qasm2/_qasm_loading.py @@ -83,10 +83,7 @@ def loads( body=body, ) - # self_arg = ir.BlockArgument(body.blocks[0], 0) # Self argument - body.blocks[0].args.append_from(MethodType, kernel_name + "_self") - # body.blocks[0]._args = (self_arg,) mt = ir.Method( sym_name=kernel_name, From 8191481889a4a7dede948f2a41004f8983bbc08d Mon Sep 17 00:00:00 2001 From: kaihsin Date: Wed, 5 Nov 2025 15:49:35 -0500 Subject: [PATCH 24/25] fix dialect group for codegen test --- .../qubit/for_loop_nontrivial_index.stim | 1 - .../passes/stim_reference_programs/qubit/nested_for_loop.stim | 1 - test/stim/passes/stim_reference_programs/qubit/nested_list.stim | 1 - .../stim_reference_programs/qubit/non_pure_loop_iterator.stim | 1 - test/stim/passes/test_squin_noise_to_stim.py | 2 +- 5 files changed, 1 insertion(+), 5 deletions(-) diff --git a/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim b/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim index 5081b891..7dc34014 100644 --- a/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim +++ b/test/stim/passes/stim_reference_programs/qubit/for_loop_nontrivial_index.stim @@ -1,4 +1,3 @@ - H 0 CX 0 1 CX 1 2 diff --git a/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim b/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim index 26adce10..f3c3c462 100644 --- a/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim +++ b/test/stim/passes/stim_reference_programs/qubit/nested_for_loop.stim @@ -1,4 +1,3 @@ - H 0 CX 0 2 CX 1 2 diff --git a/test/stim/passes/stim_reference_programs/qubit/nested_list.stim b/test/stim/passes/stim_reference_programs/qubit/nested_list.stim index 37920f35..0c58c050 100644 --- a/test/stim/passes/stim_reference_programs/qubit/nested_list.stim +++ b/test/stim/passes/stim_reference_programs/qubit/nested_list.stim @@ -1,3 +1,2 @@ - H 0 H 2 diff --git a/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim b/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim index 9088e275..ccb35ef0 100644 --- a/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim +++ b/test/stim/passes/stim_reference_programs/qubit/non_pure_loop_iterator.stim @@ -1,4 +1,3 @@ - MZ(0.00000000) 0 1 2 3 4 X 0 X 1 diff --git a/test/stim/passes/test_squin_noise_to_stim.py b/test/stim/passes/test_squin_noise_to_stim.py index b330464d..04df2837 100644 --- a/test/stim/passes/test_squin_noise_to_stim.py +++ b/test/stim/passes/test_squin_noise_to_stim.py @@ -253,7 +253,7 @@ def test(): SquinToStimPass(test.dialects)(test) buf = io.StringIO() - emit = EmitStimMain(test.dialects, correlation_identifier_offset=10, io=buf) + emit = EmitStimMain(stim.main, correlation_identifier_offset=10, io=buf) emit.initialize() emit.run(test) stim_str = buf.getvalue().strip() From e00e129dfb525477a66d7cc4445355f950a25ec1 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Thu, 6 Nov 2025 09:15:56 +0100 Subject: [PATCH 25/25] Remove unneeded run_no_raise method --- src/bloqade/analysis/fidelity/analysis.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/bloqade/analysis/fidelity/analysis.py b/src/bloqade/analysis/fidelity/analysis.py index 503cc8e9..815b5725 100644 --- a/src/bloqade/analysis/fidelity/analysis.py +++ b/src/bloqade/analysis/fidelity/analysis.py @@ -67,19 +67,12 @@ def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): return def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]: - self._run_address_analysis(method, no_raise=False) + self._run_address_analysis(method) return super().run(method, *args, **kwargs) - def run_no_raise(self, method: ir.Method, *args: Any, **kwargs: Any): - self._run_address_analysis(method, no_raise=True) - return super().run_no_raise(method, *args, **kwargs) - - def _run_address_analysis(self, method: ir.Method, no_raise: bool): + def _run_address_analysis(self, method: ir.Method): addr_analysis = AddressAnalysis(self.dialects) - if no_raise: - addr_frame, _ = addr_analysis.run_no_raise(method=method) - else: - addr_frame, _ = addr_analysis.run(method=method) + addr_frame, _ = addr_analysis.run(method=method) self.addr_frame = addr_frame # NOTE: make sure we have as many probabilities as we have addresses