Skip to content

Commit b7f159e

Browse files
committed
Automatically rewrite squin noise statements before running simulator (#385)
1 parent 70023bc commit b7f159e

File tree

3 files changed

+69
-5
lines changed

3 files changed

+69
-5
lines changed

src/bloqade/pyqrack/device.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
import numpy as np
55
from kirin import ir
6+
from kirin.passes import fold
67

8+
from bloqade.squin import noise as squin_noise
79
from pyqrack.pauli import Pauli
810
from bloqade.device import AbstractSimulatorDevice
911
from bloqade.pyqrack.reg import Measurement, PyQrackQubit
@@ -16,6 +18,7 @@
1618
_default_pyqrack_args,
1719
)
1820
from bloqade.pyqrack.task import PyQrackSimulatorTask
21+
from bloqade.squin.noise.rewrite import RewriteNoiseStmts
1922
from bloqade.analysis.address.lattice import AnyAddress
2023
from bloqade.analysis.address.analysis import AddressAnalysis
2124

@@ -47,14 +50,23 @@ def new_task(
4750
kwargs: dict[str, Any],
4851
memory: MemoryType,
4952
) -> PyQrackSimulatorTask[Params, RetType, MemoryType]:
53+
54+
if squin_noise in mt.dialects:
55+
# NOTE: rewrite noise statements
56+
mt_ = mt.similar(mt.dialects)
57+
RewriteNoiseStmts(mt_.dialects)(mt_)
58+
fold.Fold(mt_.dialects)(mt_)
59+
else:
60+
mt_ = mt
61+
5062
interp = PyQrackInterpreter(
51-
mt.dialects,
63+
mt_.dialects,
5264
memory=memory,
5365
rng_state=self.rng_state,
5466
loss_m_result=self.loss_m_result,
5567
)
5668
return PyQrackSimulatorTask(
57-
kernel=mt, args=args, kwargs=kwargs, pyqrack_interp=interp
69+
kernel=mt_, args=args, kwargs=kwargs, pyqrack_interp=interp
5870
)
5971

6072
def state_vector(

test/cirq_utils/noise/test_noise_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_simple_model(model: cirq.NoiseModel, qubits):
8282
assert math.isclose(pops[1], 0.0, abs_tol=1e-1)
8383
assert math.isclose(pops[2], 0.0, abs_tol=1e-1)
8484

85-
assert pops[0] < 0.5
86-
assert pops[3] < 0.5
85+
assert pops[0] < 0.5001
86+
assert pops[3] < 0.5001
8787
assert pops[1] > 0.0
8888
assert pops[2] > 0.0

test/pyqrack/squin/test_noise.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import math
2+
13
from bloqade import squin
2-
from bloqade.pyqrack import PyQrack, PyQrackQubit
4+
from bloqade.pyqrack import PyQrack, PyQrackQubit, StackMemorySimulator
35
from bloqade.squin.noise.stmts import NoiseChannel, StochasticUnitaryChannel
46
from bloqade.squin.noise.rewrite import RewriteNoiseStmts
57

@@ -157,3 +159,53 @@ def main():
157159

158160
target = PyQrack(1)
159161
target.run(main)
162+
163+
164+
def test_without_rewrite():
165+
@squin.kernel
166+
def main():
167+
q = squin.qubit.new(1)
168+
x = squin.op.x()
169+
squin.qubit.apply(x, q[0])
170+
171+
x_err = squin.noise.pauli_error(x, 0.1)
172+
squin.qubit.apply(x_err, q[0])
173+
return q
174+
175+
sim = StackMemorySimulator(min_qubits=1)
176+
sim.state_vector(main)
177+
178+
main.print()
179+
180+
# make sure the method wasn't updated in-place
181+
stmts = list(main.callable_region.blocks[0].stmts)
182+
assert sum([isinstance(s, StochasticUnitaryChannel) for s in stmts]) == 0
183+
assert sum([isinstance(s, squin.noise.stmts.PauliError) for s in stmts]) == 1
184+
185+
186+
def test_pauli_string_error():
187+
@squin.kernel
188+
def main():
189+
q = squin.qubit.new(2)
190+
x = squin.op.x()
191+
192+
err = squin.noise.pauli_error(x, 1.0)
193+
squin.qubit.apply(err, q[0])
194+
195+
s = squin.op.pauli_string(string="XX")
196+
err2 = squin.noise.pauli_error(s, 1.0)
197+
squin.qubit.apply(err2, q)
198+
199+
main.print()
200+
201+
RewriteNoiseStmts(main.dialects)(main)
202+
203+
main.print()
204+
205+
sim = StackMemorySimulator(min_qubits=2)
206+
207+
ket = sim.state_vector(main)
208+
209+
print(ket)
210+
211+
assert math.isclose(abs(ket[2]) ** 2, 1.0, abs_tol=1e-5)

0 commit comments

Comments
 (0)