Skip to content

Commit 709f662

Browse files
committed
Remove correlated error nonce from stim IR. Instead, stim emitter now adds tags
1 parent 19dd7b3 commit 709f662

File tree

9 files changed

+64
-30
lines changed

9 files changed

+64
-30
lines changed

src/bloqade/stim/dialects/noise/emit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,11 @@ def non_stim_corr_error(
9393
prob: tuple[str, ...] = frame.get_values(stmt.probs)
9494
prob_str: str = ", ".join(prob)
9595

96-
res = f"I_ERROR[{stmt.name}:{stmt.nonce}]({prob_str}) " + " ".join(targets)
96+
res = (
97+
f"I_ERROR[{stmt.name}:{emit.correlation_identifier_offset}]({prob_str}) "
98+
+ " ".join(targets)
99+
)
100+
emit.correlation_identifier_offset += 1
97101
emit.writeln(frame, res)
98102

99103
return ()

src/bloqade/stim/dialects/noise/stmts.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import random
2-
31
from kirin import ir, types, lowering
42
from kirin.decl import info, statement
53

@@ -91,8 +89,6 @@ class NonStimError(ir.Statement):
9189
class NonStimCorrelatedError(ir.Statement):
9290
name = "NonStimCorrelatedError"
9391
traits = frozenset({lowering.FromPythonCall()})
94-
# nonce must be a unique value, otherwise stim might merge two correlated errors
95-
nonce: int = info.attribute(default_factory=lambda: random.getrandbits(32))
9692
probs: tuple[ir.SSAValue, ...] = info.argument(types.Float)
9793
targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
9894

src/bloqade/stim/emit/stim_str.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class EmitStimMain(EmitStr):
2020
keys = ["emit.stim"]
2121
dialects: ir.DialectGroup = field(default_factory=_default_dialect_group)
2222
file: StringIO = field(default_factory=StringIO)
23+
correlation_identifier_offset: int = 0
2324

2425
def initialize(self):
2526
super().initialize()

src/bloqade/stim/parse/lowering.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -627,10 +627,13 @@ def visit_I_ERROR(
627627
# Parse tag
628628
tag_parts = node.tag.split(";", maxsplit=1)[0].split(":", maxsplit=1)
629629
nonstim_name = tag_parts[0]
630-
nonce = 0
631630
if len(tag_parts) == 2:
631+
# This should be a correlated error of the form, e.g.,
632+
# I_ERROR[correlated_loss:<identifier>](0.01) 0 1 2
633+
# The identifier is a unique number that prevents stim from merging
634+
# correlated errors. We discard the identifier, but verify it is an integer.
632635
try:
633-
nonce = int(tag_parts[1])
636+
_ = int(tag_parts[1])
634637
except ValueError:
635638
# String was not an integer
636639
if self.error_unknown_nonstim:
@@ -643,22 +646,14 @@ def visit_I_ERROR(
643646
f"Unknown non-stim statement name: {nonstim_name!r} ({node!r})"
644647
)
645648
statement_cls = self.nonstim_noise_ops.get(nonstim_name)
649+
stmt = None
646650
if statement_cls is not None:
647-
if issubclass(statement_cls, noise.NonStimCorrelatedError):
648-
stmt = statement_cls(
649-
nonce=nonce,
650-
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
651-
targets=self._get_multiple_qubit_or_rec_ssa(
652-
state, node, node.targets_copy()
653-
),
654-
)
655-
else:
656-
stmt = statement_cls(
657-
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
658-
targets=self._get_multiple_qubit_or_rec_ssa(
659-
state, node, node.targets_copy()
660-
),
661-
)
651+
stmt = statement_cls(
652+
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
653+
targets=self._get_multiple_qubit_or_rec_ssa(
654+
state, node, node.targets_copy()
655+
),
656+
)
662657
return stmt
663658

664659
def visit_CircuitInstruction(

src/bloqade/stim/rewrite/squin_noise.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
@dataclass
1717
class SquinNoiseToStim(RewriteRule):
18-
_correlated_loss_counter: int = 0
1918

2019
def rewrite_Statement(self, node: Statement) -> RewriteResult:
2120
match node:
@@ -135,9 +134,7 @@ def rewrite_CorrelatedQubitLoss(
135134
stim_stmt = stim_noise.CorrelatedQubitLoss(
136135
targets=qubit_idx_ssas,
137136
probs=(stmt.p,),
138-
nonce=self._correlated_loss_counter,
139137
)
140-
self._correlated_loss_counter += 1
141138

142139
return stim_stmt
143140

test/stim/dialects/stim/emit/test_stim_noise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_qubit_loss():
4545
def test_correlated_qubit_loss():
4646
@stim.main
4747
def test_correlated_qubit_loss():
48-
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2), nonce=3)
48+
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2))
4949

5050
out = codegen(test_correlated_qubit_loss)
51-
assert out.strip() == "I_ERROR[correlated_loss:3](0.10000000) 0 1 2"
51+
assert out.strip() == "I_ERROR[correlated_loss:0](0.10000000) 0 1 2"

test/stim/parse/test_parse_custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_parse_trivial_correlated():
2727

2828
# test roundtrip
2929
out = codegen(mt)
30-
assert out.strip() == "I_ERROR[TRIV_CORR_ERROR:3](0.20000000, 0.30000000) 5 0 1 2"
30+
assert out.strip() == "I_ERROR[TRIV_CORR_ERROR:0](0.20000000, 0.30000000) 5 0 1 2"
3131

3232

3333
def test_qubit_loss():

test/stim/passes/test_squin_noise_to_stim.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from bloqade.squin import noise, kernel
1111
from bloqade.types import Qubit, QubitType
1212
from bloqade.stim.emit import EmitStimMain
13+
from bloqade.stim.parse import loads
1314
from bloqade.stim.passes import SquinToStimPass, flatten
1415
from bloqade.stim.rewrite import SquinNoiseToStim
1516
from bloqade.squin.rewrite import WrapAddressAnalysis
17+
from bloqade.stim.dialects import noise as stim_noise
1618
from bloqade.analysis.address import AddressAnalysis
1719

1820

@@ -242,6 +244,45 @@ def test():
242244
assert codegen(test) == expected
243245

244246

247+
def test_correlated_qubit_loss_codegen_roundtrip():
248+
249+
@kernel
250+
def test():
251+
q = sq.qalloc(4)
252+
sq.correlated_qubit_loss(0.1, qubits=q[:2])
253+
sq.qubit_loss(0.2, qubit=q[2])
254+
sq.broadcast.qubit_loss(0.3, qubits=q)
255+
sq.broadcast.correlated_qubit_loss(0.1, qubits=[q[:2], q[2:]])
256+
257+
SquinToStimPass(test.dialects)(test)
258+
stim_str = codegen(test)
259+
260+
mt = loads(
261+
stim_str,
262+
nonstim_noise_ops={
263+
"loss": stim_noise.QubitLoss,
264+
"correlated_loss": stim_noise.CorrelatedQubitLoss,
265+
},
266+
)
267+
assert codegen(mt) == stim_str
268+
269+
270+
def test_correlated_qubit_loss_codegen_with_offset():
271+
272+
@kernel
273+
def test():
274+
q = sq.qalloc(4)
275+
sq.correlated_qubit_loss(0.1, qubits=q)
276+
277+
SquinToStimPass(test.dialects)(test)
278+
279+
emit = EmitStimMain(correlation_identifier_offset=10)
280+
emit.initialize()
281+
emit.run(mt=test, args=())
282+
stim_str = emit.get_output().strip()
283+
assert stim_str == "I_ERROR[correlated_loss:10](0.10000000) 0 1 2 3"
284+
285+
245286
def get_stmt_at_idx(method: ir.Method, idx: int) -> ir.Statement:
246287
return method.callable_region.blocks[0].stmts.at(idx)
247288

test/stim/wrapper/test_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,11 @@ def main_ry_wrap():
466466
def test_wrap_correlated_qubit_loss():
467467
@stim.main
468468
def main_correlated_qubit_loss():
469-
noise.CorrelatedQubitLoss(probs=(0.1,), targets=(0, 1, 2), nonce=3)
469+
noise.CorrelatedQubitLoss(probs=(0.1,), targets=(0, 1, 2))
470470

471471
@stim.main
472472
def main_correlated_qubit_loss_wrap():
473-
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2), nonce=3)
473+
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2))
474474

475475
assert main_correlated_qubit_loss.callable_region.is_structurally_equal(
476476
main_correlated_qubit_loss_wrap.callable_region

0 commit comments

Comments
 (0)