Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/bloqade/stim/dialects/noise/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ def non_stim_corr_error(
prob: tuple[str, ...] = frame.get_values(stmt.probs)
prob_str: str = ", ".join(prob)

res = f"I_ERROR[{stmt.name}:{stmt.nonce}]({prob_str}) " + " ".join(targets)
res = (
f"I_ERROR[{stmt.name}:{emit.correlated_error_count}]({prob_str}) "
+ " ".join(targets)
)
emit.correlated_error_count += 1
emit.writeln(frame, res)

return ()
4 changes: 0 additions & 4 deletions src/bloqade/stim/dialects/noise/stmts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import random

from kirin import ir, types, lowering
from kirin.decl import info, statement

Expand Down Expand Up @@ -91,8 +89,6 @@ class NonStimError(ir.Statement):
class NonStimCorrelatedError(ir.Statement):
name = "NonStimCorrelatedError"
traits = frozenset({lowering.FromPythonCall()})
# nonce must be a unique value, otherwise stim might merge two correlated errors
nonce: int = info.attribute(default_factory=lambda: random.getrandbits(32))
probs: tuple[ir.SSAValue, ...] = info.argument(types.Float)
targets: tuple[ir.SSAValue, ...] = info.argument(types.Int)

Expand Down
2 changes: 2 additions & 0 deletions src/bloqade/stim/emit/stim_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ class EmitStimMain(EmitStr):
keys = ["emit.stim"]
dialects: ir.DialectGroup = field(default_factory=_default_dialect_group)
file: StringIO = field(default_factory=StringIO)
correlation_identifier_offset: int = 0

def initialize(self):
super().initialize()
self.file.truncate(0)
self.file.seek(0)
self.correlated_error_count = self.correlation_identifier_offset
return self

def eval_stmt_fallback(
Expand Down
29 changes: 12 additions & 17 deletions src/bloqade/stim/parse/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,10 +627,13 @@ def visit_I_ERROR(
# Parse tag
tag_parts = node.tag.split(";", maxsplit=1)[0].split(":", maxsplit=1)
nonstim_name = tag_parts[0]
nonce = 0
if len(tag_parts) == 2:
# This should be a correlated error of the form, e.g.,
# I_ERROR[correlated_loss:<identifier>](0.01) 0 1 2
# The identifier is a unique number that prevents stim from merging
# correlated errors. We discard the identifier, but verify it is an integer.
try:
nonce = int(tag_parts[1])
_ = int(tag_parts[1])
except ValueError:
# String was not an integer
if self.error_unknown_nonstim:
Expand All @@ -643,22 +646,14 @@ def visit_I_ERROR(
f"Unknown non-stim statement name: {nonstim_name!r} ({node!r})"
)
statement_cls = self.nonstim_noise_ops.get(nonstim_name)
stmt = None
if statement_cls is not None:
if issubclass(statement_cls, noise.NonStimCorrelatedError):
stmt = statement_cls(
nonce=nonce,
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
targets=self._get_multiple_qubit_or_rec_ssa(
state, node, node.targets_copy()
),
)
else:
stmt = statement_cls(
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
targets=self._get_multiple_qubit_or_rec_ssa(
state, node, node.targets_copy()
),
)
stmt = statement_cls(
probs=self._get_float_args_ssa(state, node.gate_args_copy()),
targets=self._get_multiple_qubit_or_rec_ssa(
state, node, node.targets_copy()
),
)
return stmt

def visit_CircuitInstruction(
Expand Down
3 changes: 0 additions & 3 deletions src/bloqade/stim/rewrite/squin_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

@dataclass
class SquinNoiseToStim(RewriteRule):
_correlated_loss_counter: int = 0

def rewrite_Statement(self, node: Statement) -> RewriteResult:
match node:
Expand Down Expand Up @@ -135,9 +134,7 @@ def rewrite_CorrelatedQubitLoss(
stim_stmt = stim_noise.CorrelatedQubitLoss(
targets=qubit_idx_ssas,
probs=(stmt.p,),
nonce=self._correlated_loss_counter,
)
self._correlated_loss_counter += 1

return stim_stmt

Expand Down
52 changes: 49 additions & 3 deletions test/stim/dialects/stim/emit/test_stim_noise.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from bloqade import stim
from bloqade.stim.emit import EmitStimMain
from bloqade.stim.parse import loads
from bloqade.stim.dialects import noise

from .base import codegen
emit = EmitStimMain()


def codegen(mt):
# method should not have any arguments!
emit.initialize()
emit.run(mt=mt, args=())
return emit.get_output()


def test_noise():
Expand Down Expand Up @@ -45,7 +55,43 @@ def test_qubit_loss():
def test_correlated_qubit_loss():
@stim.main
def test_correlated_qubit_loss():
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2), nonce=3)
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2))

out = codegen(test_correlated_qubit_loss)
assert out.strip() == "I_ERROR[correlated_loss:3](0.10000000) 0 1 2"
assert out.strip() == "I_ERROR[correlated_loss:0](0.10000000) 0 1 2"


def test_correlated_qubit_loss_multiple():

@stim.main
def test_correlated_qubit_loss_multiple():
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1))
stim.correlated_qubit_loss(probs=(0.1,), targets=(2, 3))

for i in range(2): # repeat the test to ensure the identifier is reset each time
out = codegen(test_correlated_qubit_loss_multiple).strip()
print(out)
assert (
out.strip()
== "I_ERROR[correlated_loss:0](0.10000000) 0 1\n"
+ "I_ERROR[correlated_loss:1](0.10000000) 2 3"
)


def test_correlated_qubit_codegen_roundtrip():
@stim.main
def test():
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2))
stim.qubit_loss(probs=(0.2,), targets=(2,))
stim.correlated_qubit_loss(probs=(0.3,), targets=(3, 4))

stim_str = codegen(test)

mt = loads(
stim_str,
nonstim_noise_ops={
"loss": noise.QubitLoss,
"correlated_loss": noise.CorrelatedQubitLoss,
},
)
assert codegen(mt) == stim_str
2 changes: 1 addition & 1 deletion test/stim/parse/test_parse_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_parse_trivial_correlated():

# test roundtrip
out = codegen(mt)
assert out.strip() == "I_ERROR[TRIV_CORR_ERROR:3](0.20000000, 0.30000000) 5 0 1 2"
assert out.strip() == "I_ERROR[TRIV_CORR_ERROR:0](0.20000000, 0.30000000) 5 0 1 2"


def test_qubit_loss():
Expand Down
16 changes: 16 additions & 0 deletions test/stim/passes/test_squin_noise_to_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,22 @@ def test():
assert codegen(test) == expected


def test_correlated_qubit_loss_codegen_with_offset():

@kernel
def test():
q = sq.qalloc(4)
sq.correlated_qubit_loss(0.1, qubits=q)

SquinToStimPass(test.dialects)(test)

emit = EmitStimMain(correlation_identifier_offset=10)
emit.initialize()
emit.run(mt=test, args=())
stim_str = emit.get_output().strip()
assert stim_str == "I_ERROR[correlated_loss:10](0.10000000) 0 1 2 3"


def get_stmt_at_idx(method: ir.Method, idx: int) -> ir.Statement:
return method.callable_region.blocks[0].stmts.at(idx)

Expand Down
4 changes: 2 additions & 2 deletions test/stim/wrapper/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,11 +466,11 @@ def main_ry_wrap():
def test_wrap_correlated_qubit_loss():
@stim.main
def main_correlated_qubit_loss():
noise.CorrelatedQubitLoss(probs=(0.1,), targets=(0, 1, 2), nonce=3)
noise.CorrelatedQubitLoss(probs=(0.1,), targets=(0, 1, 2))

@stim.main
def main_correlated_qubit_loss_wrap():
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2), nonce=3)
stim.correlated_qubit_loss(probs=(0.1,), targets=(0, 1, 2))

assert main_correlated_qubit_loss.callable_region.is_structurally_equal(
main_correlated_qubit_loss_wrap.callable_region
Expand Down