diff --git a/src/bloqade/stim/dialects/noise/emit.py b/src/bloqade/stim/dialects/noise/emit.py index 8f3ebefad..41e57a06f 100644 --- a/src/bloqade/stim/dialects/noise/emit.py +++ b/src/bloqade/stim/dialects/noise/emit.py @@ -93,7 +93,11 @@ def non_stim_corr_error( prob: tuple[str, ...] = frame.get_values(stmt.probs) prob_str: str = ", ".join(prob) - res = f"I_ERROR[{stmt.name}:{stmt.nonce}]({prob_str}) " + " ".join(targets) + res = ( + f"I_ERROR[{stmt.name}:{emit.correlated_error_count}]({prob_str}) " + + " ".join(targets) + ) + emit.correlated_error_count += 1 emit.writeln(frame, res) return () diff --git a/src/bloqade/stim/dialects/noise/stmts.py b/src/bloqade/stim/dialects/noise/stmts.py index b336f901d..a63eafcd8 100644 --- a/src/bloqade/stim/dialects/noise/stmts.py +++ b/src/bloqade/stim/dialects/noise/stmts.py @@ -1,5 +1,3 @@ -import random - from kirin import ir, types, lowering from kirin.decl import info, statement @@ -91,8 +89,6 @@ class NonStimError(ir.Statement): class NonStimCorrelatedError(ir.Statement): name = "NonStimCorrelatedError" traits = frozenset({lowering.FromPythonCall()}) - # nonce must be a unique value, otherwise stim might merge two correlated errors - nonce: int = info.attribute(default_factory=lambda: random.getrandbits(32)) probs: tuple[ir.SSAValue, ...] = info.argument(types.Float) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) diff --git a/src/bloqade/stim/emit/stim_str.py b/src/bloqade/stim/emit/stim_str.py index 640d97a1f..0da37ff8a 100644 --- a/src/bloqade/stim/emit/stim_str.py +++ b/src/bloqade/stim/emit/stim_str.py @@ -20,11 +20,13 @@ class EmitStimMain(EmitStr): keys = ["emit.stim"] dialects: ir.DialectGroup = field(default_factory=_default_dialect_group) file: StringIO = field(default_factory=StringIO) + correlation_identifier_offset: int = 0 def initialize(self): super().initialize() self.file.truncate(0) self.file.seek(0) + self.correlated_error_count = self.correlation_identifier_offset return self def eval_stmt_fallback( diff --git a/src/bloqade/stim/parse/lowering.py b/src/bloqade/stim/parse/lowering.py index 84a1a4b7f..8efe93696 100644 --- a/src/bloqade/stim/parse/lowering.py +++ b/src/bloqade/stim/parse/lowering.py @@ -627,10 +627,13 @@ def visit_I_ERROR( # Parse tag tag_parts = node.tag.split(";", maxsplit=1)[0].split(":", maxsplit=1) nonstim_name = tag_parts[0] - nonce = 0 if len(tag_parts) == 2: + # This should be a correlated error of the form, e.g., + # I_ERROR[correlated_loss:](0.01) 0 1 2 + # The identifier is a unique number that prevents stim from merging + # correlated errors. We discard the identifier, but verify it is an integer. try: - nonce = int(tag_parts[1]) + _ = int(tag_parts[1]) except ValueError: # String was not an integer if self.error_unknown_nonstim: @@ -643,22 +646,14 @@ def visit_I_ERROR( f"Unknown non-stim statement name: {nonstim_name!r} ({node!r})" ) statement_cls = self.nonstim_noise_ops.get(nonstim_name) + stmt = None if statement_cls is not None: - if issubclass(statement_cls, noise.NonStimCorrelatedError): - stmt = statement_cls( - nonce=nonce, - probs=self._get_float_args_ssa(state, node.gate_args_copy()), - targets=self._get_multiple_qubit_or_rec_ssa( - state, node, node.targets_copy() - ), - ) - else: - stmt = statement_cls( - probs=self._get_float_args_ssa(state, node.gate_args_copy()), - targets=self._get_multiple_qubit_or_rec_ssa( - state, node, node.targets_copy() - ), - ) + stmt = statement_cls( + probs=self._get_float_args_ssa(state, node.gate_args_copy()), + targets=self._get_multiple_qubit_or_rec_ssa( + state, node, node.targets_copy() + ), + ) return stmt def visit_CircuitInstruction( diff --git a/src/bloqade/stim/rewrite/squin_noise.py b/src/bloqade/stim/rewrite/squin_noise.py index ed27804d8..17140c19f 100644 --- a/src/bloqade/stim/rewrite/squin_noise.py +++ b/src/bloqade/stim/rewrite/squin_noise.py @@ -15,7 +15,6 @@ @dataclass class SquinNoiseToStim(RewriteRule): - _correlated_loss_counter: int = 0 def rewrite_Statement(self, node: Statement) -> RewriteResult: match node: @@ -135,9 +134,7 @@ def rewrite_CorrelatedQubitLoss( stim_stmt = stim_noise.CorrelatedQubitLoss( targets=qubit_idx_ssas, probs=(stmt.p,), - nonce=self._correlated_loss_counter, ) - self._correlated_loss_counter += 1 return stim_stmt diff --git a/test/stim/dialects/stim/emit/test_stim_noise.py b/test/stim/dialects/stim/emit/test_stim_noise.py index 4f1f19f42..773bc5c12 100644 --- a/test/stim/dialects/stim/emit/test_stim_noise.py +++ b/test/stim/dialects/stim/emit/test_stim_noise.py @@ -1,6 +1,16 @@ from bloqade import stim +from bloqade.stim.emit import EmitStimMain +from bloqade.stim.parse import loads +from bloqade.stim.dialects import noise -from .base import codegen +emit = EmitStimMain() + + +def codegen(mt): + # method should not have any arguments! + emit.initialize() + emit.run(mt=mt, args=()) + return emit.get_output() def test_noise(): @@ -45,7 +55,43 @@ def test_qubit_loss(): def test_correlated_qubit_loss(): @stim.main def test_correlated_qubit_loss(): - stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2), nonce=3) + stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2)) out = codegen(test_correlated_qubit_loss) - assert out.strip() == "I_ERROR[correlated_loss:3](0.10000000) 0 1 2" + assert out.strip() == "I_ERROR[correlated_loss:0](0.10000000) 0 1 2" + + +def test_correlated_qubit_loss_multiple(): + + @stim.main + def test_correlated_qubit_loss_multiple(): + stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1)) + stim.correlated_qubit_loss(probs=(0.1,), targets=(2, 3)) + + for i in range(2): # repeat the test to ensure the identifier is reset each time + out = codegen(test_correlated_qubit_loss_multiple).strip() + print(out) + assert ( + out.strip() + == "I_ERROR[correlated_loss:0](0.10000000) 0 1\n" + + "I_ERROR[correlated_loss:1](0.10000000) 2 3" + ) + + +def test_correlated_qubit_codegen_roundtrip(): + @stim.main + def test(): + stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2)) + stim.qubit_loss(probs=(0.2,), targets=(2,)) + stim.correlated_qubit_loss(probs=(0.3,), targets=(3, 4)) + + stim_str = codegen(test) + + mt = loads( + stim_str, + nonstim_noise_ops={ + "loss": noise.QubitLoss, + "correlated_loss": noise.CorrelatedQubitLoss, + }, + ) + assert codegen(mt) == stim_str diff --git a/test/stim/parse/test_parse_custom.py b/test/stim/parse/test_parse_custom.py index 70d74d8bb..c1208a506 100644 --- a/test/stim/parse/test_parse_custom.py +++ b/test/stim/parse/test_parse_custom.py @@ -27,7 +27,7 @@ def test_parse_trivial_correlated(): # test roundtrip out = codegen(mt) - assert out.strip() == "I_ERROR[TRIV_CORR_ERROR:3](0.20000000, 0.30000000) 5 0 1 2" + assert out.strip() == "I_ERROR[TRIV_CORR_ERROR:0](0.20000000, 0.30000000) 5 0 1 2" def test_qubit_loss(): diff --git a/test/stim/passes/test_squin_noise_to_stim.py b/test/stim/passes/test_squin_noise_to_stim.py index b3087ed64..5775acc57 100644 --- a/test/stim/passes/test_squin_noise_to_stim.py +++ b/test/stim/passes/test_squin_noise_to_stim.py @@ -242,6 +242,22 @@ def test(): assert codegen(test) == expected +def test_correlated_qubit_loss_codegen_with_offset(): + + @kernel + def test(): + q = sq.qalloc(4) + sq.correlated_qubit_loss(0.1, qubits=q) + + SquinToStimPass(test.dialects)(test) + + emit = EmitStimMain(correlation_identifier_offset=10) + emit.initialize() + emit.run(mt=test, args=()) + stim_str = emit.get_output().strip() + assert stim_str == "I_ERROR[correlated_loss:10](0.10000000) 0 1 2 3" + + def get_stmt_at_idx(method: ir.Method, idx: int) -> ir.Statement: return method.callable_region.blocks[0].stmts.at(idx) diff --git a/test/stim/wrapper/test_wrapper.py b/test/stim/wrapper/test_wrapper.py index 53744a574..aeae367bb 100644 --- a/test/stim/wrapper/test_wrapper.py +++ b/test/stim/wrapper/test_wrapper.py @@ -466,11 +466,11 @@ def main_ry_wrap(): def test_wrap_correlated_qubit_loss(): @stim.main def main_correlated_qubit_loss(): - noise.CorrelatedQubitLoss(probs=(0.1,), targets=(0, 1, 2), nonce=3) + noise.CorrelatedQubitLoss(probs=(0.1,), targets=(0, 1, 2)) @stim.main def main_correlated_qubit_loss_wrap(): - stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2), nonce=3) + stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2)) assert main_correlated_qubit_loss.callable_region.is_structurally_equal( main_correlated_qubit_loss_wrap.callable_region