|
10 | 10 | from bloqade.squin import noise, kernel |
11 | 11 | from bloqade.types import Qubit, QubitType |
12 | 12 | from bloqade.stim.emit import EmitStimMain |
| 13 | +from bloqade.stim.parse import loads |
13 | 14 | from bloqade.stim.passes import SquinToStimPass, flatten |
14 | 15 | from bloqade.stim.rewrite import SquinNoiseToStim |
15 | 16 | from bloqade.squin.rewrite import WrapAddressAnalysis |
| 17 | +from bloqade.stim.dialects import noise as stim_noise |
16 | 18 | from bloqade.analysis.address import AddressAnalysis |
17 | 19 |
|
18 | 20 |
|
@@ -242,6 +244,45 @@ def test(): |
242 | 244 | assert codegen(test) == expected |
243 | 245 |
|
244 | 246 |
|
| 247 | +def test_correlated_qubit_loss_codegen_roundtrip(): |
| 248 | + |
| 249 | + @kernel |
| 250 | + def test(): |
| 251 | + q = sq.qalloc(4) |
| 252 | + sq.correlated_qubit_loss(0.1, qubits=q[:2]) |
| 253 | + sq.qubit_loss(0.2, qubit=q[2]) |
| 254 | + sq.broadcast.qubit_loss(0.3, qubits=q) |
| 255 | + sq.broadcast.correlated_qubit_loss(0.1, qubits=[q[:2], q[2:]]) |
| 256 | + |
| 257 | + SquinToStimPass(test.dialects)(test) |
| 258 | + stim_str = codegen(test) |
| 259 | + |
| 260 | + mt = loads( |
| 261 | + stim_str, |
| 262 | + nonstim_noise_ops={ |
| 263 | + "loss": stim_noise.QubitLoss, |
| 264 | + "correlated_loss": stim_noise.CorrelatedQubitLoss, |
| 265 | + }, |
| 266 | + ) |
| 267 | + assert codegen(mt) == stim_str |
| 268 | + |
| 269 | + |
| 270 | +def test_correlated_qubit_loss_codegen_with_offset(): |
| 271 | + |
| 272 | + @kernel |
| 273 | + def test(): |
| 274 | + q = sq.qalloc(4) |
| 275 | + sq.correlated_qubit_loss(0.1, qubits=q) |
| 276 | + |
| 277 | + SquinToStimPass(test.dialects)(test) |
| 278 | + |
| 279 | + emit = EmitStimMain(correlation_identifier_offset=10) |
| 280 | + emit.initialize() |
| 281 | + emit.run(mt=test, args=()) |
| 282 | + stim_str = emit.get_output().strip() |
| 283 | + assert stim_str == "I_ERROR[correlated_loss:10](0.10000000) 0 1 2 3" |
| 284 | + |
| 285 | + |
245 | 286 | def get_stmt_at_idx(method: ir.Method, idx: int) -> ir.Statement: |
246 | 287 | return method.callable_region.blocks[0].stmts.at(idx) |
247 | 288 |
|
|
0 commit comments