Skip to content

Commit 7e56608

Browse files
committed
add closure for kernels that have arguments
1 parent 8b952b9 commit 7e56608

File tree

2 files changed

+169
-38
lines changed

2 files changed

+169
-38
lines changed

src/bloqade/pyqrack/device.py

Lines changed: 130 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,47 @@
1-
from typing import Any, TypeVar, ParamSpec
1+
from typing import Any, Generic, TypeVar, ParamSpec, cast
22
from dataclasses import field, dataclass
33

44
import numpy as np
55
from kirin import ir
6+
from kirin.dialects import py, func
67

8+
from bloqade.noise import native
79
from pyqrack.pauli import Pauli
810
from bloqade.device import AbstractSimulatorDevice
911
from bloqade.pyqrack.reg import Measurement, PyQrackQubit
1012
from bloqade.pyqrack.base import (
11-
MemoryABC,
1213
StackMemory,
1314
DynamicMemory,
1415
PyQrackOptions,
1516
PyQrackInterpreter,
1617
_default_pyqrack_args,
1718
)
18-
from bloqade.pyqrack.task import PyQrackSimulatorTask
19+
from bloqade.pyqrack.task import PyQrackSimulatorTask, PyQrackNoiseSimulatorTask
20+
from bloqade.qasm2.passes import NoisePass, QASM2Fold, UOpToParallel
21+
from bloqade.analysis.fidelity import FidelityAnalysis
1922
from bloqade.analysis.address.lattice import AnyAddress
2023
from bloqade.analysis.address.analysis import AddressAnalysis
2124

2225
RetType = TypeVar("RetType")
2326
Params = ParamSpec("Params")
2427

28+
PyQrackSimulatorTaskType = TypeVar(
29+
"PyQrackSimulatorTaskType",
30+
bound=PyQrackSimulatorTask,
31+
)
32+
2533

2634
@dataclass
27-
class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTask]):
35+
class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTaskType]):
2836
options: PyQrackOptions = field(default_factory=_default_pyqrack_args)
2937
loss_m_result: Measurement = field(default=Measurement.One, kw_only=True)
3038
rng_state: np.random.Generator = field(
3139
default_factory=np.random.default_rng, kw_only=True
3240
)
3341

34-
MemoryType = TypeVar("MemoryType", bound=MemoryABC)
35-
3642
def __post_init__(self):
3743
self.options = PyQrackOptions({**_default_pyqrack_args(), **self.options})
3844

39-
def new_task(
40-
self,
41-
mt: ir.Method[Params, RetType],
42-
args: tuple[Any, ...],
43-
kwargs: dict[str, Any],
44-
memory: MemoryType,
45-
) -> PyQrackSimulatorTask[Params, RetType, MemoryType]:
46-
interp = PyQrackInterpreter(
47-
mt.dialects,
48-
memory=memory,
49-
rng_state=self.rng_state,
50-
loss_m_result=self.loss_m_result,
51-
)
52-
return PyQrackSimulatorTask(
53-
kernel=mt, args=args, kwargs=kwargs, pyqrack_interp=interp
54-
)
55-
5645
def state_vector(
5746
self,
5847
kernel: ir.Method[Params, RetType],
@@ -98,7 +87,7 @@ def pauli_expectation(pauli: list[Pauli], qubits: list[PyQrackQubit]) -> float:
9887

9988

10089
@dataclass
101-
class StackMemorySimulator(PyQrackSimulatorBase):
90+
class StackMemorySimulator(PyQrackSimulatorBase[PyQrackSimulatorTask]):
10291
"""PyQrack simulator device with precalculated stack of qubits."""
10392

10493
min_qubits: int = field(default=0, kw_only=True)
@@ -129,11 +118,20 @@ def task(
129118
total=num_qubits,
130119
)
131120

132-
return self.new_task(kernel, args, kwargs, memory)
121+
pyqrack_interp = PyQrackInterpreter(
122+
kernel.dialects,
123+
memory=memory,
124+
rng_state=self.rng_state,
125+
loss_m_result=self.loss_m_result,
126+
)
127+
128+
return PyQrackSimulatorTask(
129+
kernel=kernel, args=args, kwargs=kwargs, pyqrack_interp=pyqrack_interp
130+
)
133131

134132

135133
@dataclass
136-
class DynamicMemorySimulator(PyQrackSimulatorBase):
134+
class DynamicMemorySimulator(PyQrackSimulatorBase[PyQrackSimulatorTask]):
137135
"""PyQrack simulator device with dynamic qubit allocation."""
138136

139137
def task(
@@ -145,20 +143,114 @@ def task(
145143
if kwargs is None:
146144
kwargs = {}
147145

148-
memory = DynamicMemory(self.options.copy())
149-
return self.new_task(kernel, args, kwargs, memory)
146+
pyqrack_interp = PyQrackInterpreter(
147+
kernel.dialects,
148+
memory=DynamicMemory(self.options.copy()),
149+
rng_state=self.rng_state,
150+
loss_m_result=self.loss_m_result,
151+
)
152+
153+
return PyQrackSimulatorTask(
154+
kernel=kernel,
155+
args=args,
156+
kwargs=kwargs,
157+
pyqrack_interp=pyqrack_interp,
158+
)
159+
160+
161+
def arg_closure(
162+
kernel: ir.Method[Params, RetType], args: tuple[Any, ...], kwargs: dict[str, Any]
163+
) -> ir.Method[..., RetType]:
164+
"""Create a closure for the arguments of the kernel."""
165+
166+
func_body = ir.Region(block := ir.Block())
167+
inputs: list[ir.ResultValue] = []
168+
for arg in args:
169+
block.stmts.append(const_stmt := py.Constant(arg))
170+
inputs.append(const_stmt.result)
171+
172+
kw_names: list[str] = []
173+
for key, value in kwargs.items():
174+
block.stmts.append(const_stmt := py.Constant(value))
175+
kw_names.append(key)
176+
inputs.append(const_stmt.result)
177+
178+
block.stmts.append(
179+
invoke_stmt := func.Invoke(
180+
inputs=tuple(inputs),
181+
callee=kernel,
182+
kwargs=tuple(kw_names),
183+
purity=False,
184+
)
185+
)
186+
block.stmts.append(func.Return(invoke_stmt.result))
187+
188+
code = func.Function(
189+
sym_name="closure",
190+
signature=func.Signature((), kernel.return_type),
191+
body=func_body,
192+
)
193+
return ir.Method(None, None, "closure", [], kernel.dialects, code)
194+
195+
196+
NoiseModelType = TypeVar("NoiseModelType", bound=native.MoveNoiseModelABC)
197+
198+
199+
@dataclass
200+
class NoiseSimulatorBase(
201+
PyQrackSimulatorBase[PyQrackNoiseSimulatorTask], Generic[NoiseModelType]
202+
):
203+
noise_model: NoiseModelType = field(default_factory=native.TwoRowZoneModel)
204+
gate_noise_params: native.GateNoiseParams = field(
205+
default_factory=native.GateNoiseParams
206+
)
207+
optimize_parallel_gates: bool = field(default=True, kw_only=True)
208+
decompose_native_gates: bool = field(default=True, kw_only=True)
209+
210+
def task(
211+
self,
212+
kernel: ir.Method[Params, RetType],
213+
args: tuple[Any, ...] = (),
214+
kwargs: dict[str, Any] | None = None,
215+
):
216+
if kwargs is None:
217+
kwargs = {}
218+
219+
if len(args) > 0 or len(kwargs) > 0:
220+
folded_kernel = arg_closure(kernel, args, kwargs)
221+
args = ()
222+
kwargs = {}
223+
else:
224+
folded_kernel = cast(ir.Method[..., RetType], kernel)
150225

226+
QASM2Fold(folded_kernel.dialects).fixpoint(folded_kernel)
151227

152-
def test():
153-
from bloqade.qasm2 import extended
228+
if self.optimize_parallel_gates:
229+
UOpToParallel(folded_kernel.dialects)(folded_kernel)
154230

155-
@extended
156-
def main():
157-
return 1
231+
if native.dialect not in folded_kernel.dialects:
232+
noise_pass = NoisePass(
233+
kernel.dialects,
234+
self.noise_model,
235+
self.gate_noise_params,
236+
)
158237

159-
@extended
160-
def obs(result: int) -> int:
161-
return result
238+
noise_pass(folded_kernel)
239+
folded_kernel = folded_kernel.similar(
240+
folded_kernel.dialects.add(native.dialect)
241+
)
242+
243+
pyqrack_interp = PyQrackInterpreter(
244+
folded_kernel.dialects,
245+
memory=DynamicMemory(self.options.copy()),
246+
rng_state=self.rng_state,
247+
loss_m_result=self.loss_m_result,
248+
)
162249

163-
res = DynamicMemorySimulator().task(main)
164-
return res.run()
250+
return PyQrackNoiseSimulatorTask(
251+
kernel=folded_kernel,
252+
args=args,
253+
kwargs=kwargs,
254+
pyqrack_interp=pyqrack_interp,
255+
fidelity_scorer=FidelityAnalysis(kernel.dialects),
256+
)

src/bloqade/pyqrack/task.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from typing import TypeVar, ParamSpec, cast
22
from dataclasses import dataclass
33

4+
import numpy as np
5+
46
from bloqade.task import AbstractSimulatorTask
57
from bloqade.pyqrack.base import (
68
MemoryABC,
79
PyQrackInterpreter,
810
)
11+
from bloqade.analysis.fidelity import FidelityAnalysis
912

1013
RetType = TypeVar("RetType")
1114
Param = ParamSpec("Param")
@@ -36,3 +39,39 @@ def state_vector(self) -> list[complex]:
3639
"""Returns the state vector of the simulator."""
3740
self.run()
3841
return self.state.sim_reg.out_ket()
42+
43+
44+
@dataclass
45+
class PyQrackNoiseSimulatorTask(PyQrackSimulatorTask[Param, RetType, MemoryType]):
46+
"""PyQrack noise simulator task for Bloqade."""
47+
48+
fidelity_scorer: FidelityAnalysis
49+
50+
@dataclass(frozen=True)
51+
class FidelityResult:
52+
"""Stores the results of the fidelity analysis."""
53+
54+
gate_fidelity: float
55+
"""The global fidelity of the circuit execution."""
56+
atom_survival_probability: list[float]
57+
"""The survival probability of each qubit in the circuit."""
58+
59+
@property
60+
def typical_survival_probability(self) -> float:
61+
"""Returns the typical survival probability of the qubits."""
62+
return float(np.median(self.atom_survival_probability))
63+
64+
def run(self) -> RetType:
65+
return self.pyqrack_interp.run(
66+
self.kernel,
67+
args=self.args,
68+
kwargs=self.kwargs,
69+
)
70+
71+
def fidelity(self) -> FidelityResult:
72+
_, _ = self.fidelity_scorer.run_analysis(self.kernel)
73+
74+
return self.FidelityResult(
75+
self.fidelity_scorer.gate_fidelity,
76+
self.fidelity_scorer.atom_survival_probability,
77+
)

0 commit comments

Comments
 (0)