Skip to content

Commit 1e347d8

Browse files
authored
Remove correlated error nonce from stim IR (#560)
1 parent 19dd7b3 commit 1e347d8

File tree

9 files changed

+87
-31
lines changed

9 files changed

+87
-31
lines changed

src/bloqade/stim/dialects/noise/emit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,11 @@ def non_stim_corr_error(
9393
prob: tuple[str, ...] = frame.get_values(stmt.probs)
9494
prob_str: str = ", ".join(prob)
9595

96-
res = f"I_ERROR[{stmt.name}:{stmt.nonce}]({prob_str}) " + " ".join(targets)
96+
res = (
97+
f"I_ERROR[{stmt.name}:{emit.correlated_error_count}]({prob_str}) "
98+
+ " ".join(targets)
99+
)
100+
emit.correlated_error_count += 1
97101
emit.writeln(frame, res)
98102

99103
return ()

src/bloqade/stim/dialects/noise/stmts.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import random
2-
31
from kirin import ir, types, lowering
42
from kirin.decl import info, statement
53

@@ -91,8 +89,6 @@ class NonStimError(ir.Statement):
9189
class NonStimCorrelatedError(ir.Statement):
9290
name = "NonStimCorrelatedError"
9391
traits = frozenset({lowering.FromPythonCall()})
94-
# nonce must be a unique value, otherwise stim might merge two correlated errors
95-
nonce: int = info.attribute(default_factory=lambda: random.getrandbits(32))
9692
probs: tuple[ir.SSAValue, ...] = info.argument(types.Float)
9793
targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)
9894

src/bloqade/stim/emit/stim_str.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ class EmitStimMain(EmitStr):
2020
keys = ["emit.stim"]
2121
dialects: ir.DialectGroup = field(default_factory=_default_dialect_group)
2222
file: StringIO = field(default_factory=StringIO)
23+
correlation_identifier_offset: int = 0
2324

2425
def initialize(self):
2526
super().initialize()
2627
self.file.truncate(0)
2728
self.file.seek(0)
29+
self.correlated_error_count = self.correlation_identifier_offset
2830
return self
2931

3032
def eval_stmt_fallback(

src/bloqade/stim/parse/lowering.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -627,10 +627,13 @@ def visit_I_ERROR(
627627
# Parse tag
628628
tag_parts = node.tag.split(";", maxsplit=1)[0].split(":", maxsplit=1)
629629
nonstim_name = tag_parts[0]
630-
nonce = 0
631630
if len(tag_parts) == 2:
631+
# This should be a correlated error of the form, e.g.,
632+
# I_ERROR[correlated_loss:<identifier>](0.01) 0 1 2
633+
# The identifier is a unique number that prevents stim from merging
634+
# correlated errors. We discard the identifier, but verify it is an integer.
632635
try:
633-
nonce = int(tag_parts[1])
636+
_ = int(tag_parts[1])
634637
except ValueError:
635638
# String was not an integer
636639
if self.error_unknown_nonstim:
@@ -643,22 +646,14 @@ def visit_I_ERROR(
643646
f"Unknown non-stim statement name: {nonstim_name!r} ({node!r})"
644647
)
645648
statement_cls = self.nonstim_noise_ops.get(nonstim_name)
649+
stmt = None
646650
if statement_cls is not None:
647-
if issubclass(statement_cls, noise.NonStimCorrelatedError):
648-
stmt = statement_cls(
649-
nonce=nonce,
650-
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
651-
targets=self._get_multiple_qubit_or_rec_ssa(
652-
state, node, node.targets_copy()
653-
),
654-
)
655-
else:
656-
stmt = statement_cls(
657-
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
658-
targets=self._get_multiple_qubit_or_rec_ssa(
659-
state, node, node.targets_copy()
660-
),
661-
)
651+
stmt = statement_cls(
652+
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
653+
targets=self._get_multiple_qubit_or_rec_ssa(
654+
state, node, node.targets_copy()
655+
),
656+
)
662657
return stmt
663658

664659
def visit_CircuitInstruction(

src/bloqade/stim/rewrite/squin_noise.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
@dataclass
1717
class SquinNoiseToStim(RewriteRule):
18-
_correlated_loss_counter: int = 0
1918

2019
def rewrite_Statement(self, node: Statement) -> RewriteResult:
2120
match node:
@@ -135,9 +134,7 @@ def rewrite_CorrelatedQubitLoss(
135134
stim_stmt = stim_noise.CorrelatedQubitLoss(
136135
targets=qubit_idx_ssas,
137136
probs=(stmt.p,),
138-
nonce=self._correlated_loss_counter,
139137
)
140-
self._correlated_loss_counter += 1
141138

142139
return stim_stmt
143140

test/stim/dialects/stim/emit/test_stim_noise.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
from bloqade import stim
2+
from bloqade.stim.emit import EmitStimMain
3+
from bloqade.stim.parse import loads
4+
from bloqade.stim.dialects import noise
25

3-
from .base import codegen
6+
emit = EmitStimMain()
7+
8+
9+
def codegen(mt):
10+
# method should not have any arguments!
11+
emit.initialize()
12+
emit.run(mt=mt, args=())
13+
return emit.get_output()
414

515

616
def test_noise():
@@ -45,7 +55,43 @@ def test_qubit_loss():
4555
def test_correlated_qubit_loss():
4656
@stim.main
4757
def test_correlated_qubit_loss():
48-
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2), nonce=3)
58+
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2))
4959

5060
out = codegen(test_correlated_qubit_loss)
51-
assert out.strip() == "I_ERROR[correlated_loss:3](0.10000000) 0 1 2"
61+
assert out.strip() == "I_ERROR[correlated_loss:0](0.10000000) 0 1 2"
62+
63+
64+
def test_correlated_qubit_loss_multiple():
65+
66+
@stim.main
67+
def test_correlated_qubit_loss_multiple():
68+
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1))
69+
stim.correlated_qubit_loss(probs=(0.1,), targets=(2, 3))
70+
71+
for i in range(2): # repeat the test to ensure the identifier is reset each time
72+
out = codegen(test_correlated_qubit_loss_multiple).strip()
73+
print(out)
74+
assert (
75+
out.strip()
76+
== "I_ERROR[correlated_loss:0](0.10000000) 0 1\n"
77+
+ "I_ERROR[correlated_loss:1](0.10000000) 2 3"
78+
)
79+
80+
81+
def test_correlated_qubit_codegen_roundtrip():
82+
@stim.main
83+
def test():
84+
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2))
85+
stim.qubit_loss(probs=(0.2,), targets=(2,))
86+
stim.correlated_qubit_loss(probs=(0.3,), targets=(3, 4))
87+
88+
stim_str = codegen(test)
89+
90+
mt = loads(
91+
stim_str,
92+
nonstim_noise_ops={
93+
"loss": noise.QubitLoss,
94+
"correlated_loss": noise.CorrelatedQubitLoss,
95+
},
96+
)
97+
assert codegen(mt) == stim_str

test/stim/parse/test_parse_custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_parse_trivial_correlated():
2727

2828
# test roundtrip
2929
out = codegen(mt)
30-
assert out.strip() == "I_ERROR[TRIV_CORR_ERROR:3](0.20000000, 0.30000000) 5 0 1 2"
30+
assert out.strip() == "I_ERROR[TRIV_CORR_ERROR:0](0.20000000, 0.30000000) 5 0 1 2"
3131

3232

3333
def test_qubit_loss():

test/stim/passes/test_squin_noise_to_stim.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,22 @@ def test():
242242
assert codegen(test) == expected
243243

244244

245+
def test_correlated_qubit_loss_codegen_with_offset():
246+
247+
@kernel
248+
def test():
249+
q = sq.qalloc(4)
250+
sq.correlated_qubit_loss(0.1, qubits=q)
251+
252+
SquinToStimPass(test.dialects)(test)
253+
254+
emit = EmitStimMain(correlation_identifier_offset=10)
255+
emit.initialize()
256+
emit.run(mt=test, args=())
257+
stim_str = emit.get_output().strip()
258+
assert stim_str == "I_ERROR[correlated_loss:10](0.10000000) 0 1 2 3"
259+
260+
245261
def get_stmt_at_idx(method: ir.Method, idx: int) -> ir.Statement:
246262
return method.callable_region.blocks[0].stmts.at(idx)
247263

test/stim/wrapper/test_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,11 @@ def main_ry_wrap():
466466
def test_wrap_correlated_qubit_loss():
467467
@stim.main
468468
def main_correlated_qubit_loss():
469-
noise.CorrelatedQubitLoss(probs=(0.1,), targets=(0, 1, 2), nonce=3)
469+
noise.CorrelatedQubitLoss(probs=(0.1,), targets=(0, 1, 2))
470470

471471
@stim.main
472472
def main_correlated_qubit_loss_wrap():
473-
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2), nonce=3)
473+
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2))
474474

475475
assert main_correlated_qubit_loss.callable_region.is_structurally_equal(
476476
main_correlated_qubit_loss_wrap.callable_region

0 commit comments

Comments
 (0)