Skip to content

Commit b817f85

Browse files
committed
adding pyqrack impl into circuits repo
1 parent 86d5da0 commit b817f85

File tree

20 files changed

+1284
-0
lines changed

20 files changed

+1284
-0
lines changed

_typos.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ AttributeIDSupressMenu = "AttributeIDSupressMenu"
1212
Braket = "Braket"
1313
mch = "mch"
1414
IY = "IY"
15+
ket = "ket"

pyproject.toml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ dependencies = [
1717
"rich>=13.9.4",
1818
"pydantic>=1.3.0,<2.11.0",
1919
"pandas>=2.2.3",
20+
"pyqrack[pyqrack]>=1.38.2",
21+
"pyqrack-cpu[pyqrack-cpu]>=1.38.2",
2022
]
2123

2224
[project.optional-dependencies]
@@ -76,6 +78,26 @@ doc = [
7678
examples = [
7779
"networkx>=3.4.2",
7880
]
81+
dev-linux = [
82+
"cirq-core[contrib]>=1.4.1",
83+
"lark>=1.2.2",
84+
"pyqrack-cpu>=1.38.2",
85+
"qbraid>=0.9.5",
86+
"ffmpeg>=1.4",
87+
"matplotlib>=3.10.0",
88+
"pyqt5>=5.15.11",
89+
"tqdm>=4.67.1",
90+
]
91+
dev-mac-arm = [
92+
"cirq-core[contrib]>=1.4.1",
93+
"ffmpeg>=1.4",
94+
"lark>=1.2.2",
95+
"matplotlib>=3.10.0",
96+
"pyqrack>=1.38.2",
97+
"pyqt5>=5.15.11",
98+
"qbraid>=0.9.5",
99+
"tqdm>=4.67.1",
100+
]
79101

80102
[tool.isort]
81103
profile = "black"

src/bloqade/pyqrack/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from .reg import (
2+
CBitRef as CBitRef,
3+
CRegister as CRegister,
4+
PyQrackReg as PyQrackReg,
5+
QubitState as QubitState,
6+
Measurement as Measurement,
7+
PyQrackQubit as PyQrackQubit,
8+
)
9+
from .base import (
10+
StackMemory as StackMemory,
11+
DynamicMemory as DynamicMemory,
12+
PyQrackInterpreter as PyQrackInterpreter,
13+
)
14+
15+
# NOTE: The following import is for registering the method tables
16+
from .noise import native as native
17+
from .qasm2 import uop as uop, core as core, parallel as parallel
18+
from .target import PyQrack as PyQrack

src/bloqade/pyqrack/base.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import abc
2+
import typing
3+
from dataclasses import field, dataclass
4+
from unittest.mock import Mock
5+
6+
import numpy as np
7+
from kirin.interp import Interpreter
8+
from typing_extensions import Self
9+
from kirin.interp.exceptions import InterpreterError
10+
11+
from bloqade.pyqrack.reg import Measurement
12+
13+
if typing.TYPE_CHECKING:
14+
from pyqrack import QrackSimulator
15+
16+
17+
class PyQrackOptions(typing.TypedDict):
18+
qubitCount: int
19+
isTensorNetwork: bool
20+
isSchmidtDecomposeMulti: bool
21+
isSchmidtDecompose: bool
22+
isStabilizerHybrid: bool
23+
isBinaryDecisionTree: bool
24+
isPaged: bool
25+
isCpuGpuHybrid: bool
26+
isOpenCL: bool
27+
28+
29+
def _default_pyqrack_args() -> PyQrackOptions:
30+
return PyQrackOptions(
31+
qubitCount=-1,
32+
isTensorNetwork=False,
33+
isSchmidtDecomposeMulti=True,
34+
isSchmidtDecompose=True,
35+
isStabilizerHybrid=True,
36+
isBinaryDecisionTree=True,
37+
isPaged=True,
38+
isCpuGpuHybrid=True,
39+
isOpenCL=True,
40+
)
41+
42+
43+
@dataclass
44+
class MemoryABC(abc.ABC):
45+
pyqrack_options: PyQrackOptions = field(default_factory=_default_pyqrack_args)
46+
sim_reg: "QrackSimulator" = field(init=False)
47+
48+
@abc.abstractmethod
49+
def allocate(self, n_qubits: int) -> tuple[int, ...]:
50+
"""Allocate `n_qubits` qubits and return their ids."""
51+
...
52+
53+
def reset(self):
54+
"""Reset the memory, releasing all qubits."""
55+
from pyqrack import QrackSimulator
56+
57+
# do not reset the simulator it might be used by
58+
# results of the simulation
59+
self.sim_reg = QrackSimulator(**self.pyqrack_options)
60+
61+
62+
@dataclass
63+
class MockMemory(MemoryABC):
64+
"""Mock memory for testing purposes."""
65+
66+
allocated: int = field(init=False, default=0)
67+
68+
def allocate(self, n_qubits: int):
69+
allocated = self.allocated + n_qubits
70+
result = tuple(range(self.allocated, allocated))
71+
self.allocated = allocated
72+
return result
73+
74+
def reset(self):
75+
self.allocated = 0
76+
self.sim_reg = Mock()
77+
78+
79+
@dataclass
80+
class StackMemory(MemoryABC):
81+
total: int = field(kw_only=True)
82+
allocated: int = field(init=False, default=0)
83+
84+
def allocate(self, n_qubits: int):
85+
curr_allocated = self.allocated
86+
self.allocated += n_qubits
87+
88+
if self.allocated > self.total:
89+
raise InterpreterError(
90+
f"qubit allocation exceeds memory, "
91+
f"{self.total} qubits, "
92+
f"{self.allocated} allocated"
93+
)
94+
95+
return tuple(range(curr_allocated, self.allocated))
96+
97+
def reset(self):
98+
super().reset()
99+
self.allocated = 0
100+
101+
102+
@dataclass
103+
class DynamicMemory(MemoryABC):
104+
def __post_init__(self):
105+
self.reset()
106+
107+
if self.sim_reg.is_tensor_network:
108+
raise ValueError("DynamicMemory does not support tensor networks")
109+
110+
def allocate(self, n_qubits: int):
111+
start = self.sim_reg.num_qubits()
112+
for i in range(start, start + n_qubits):
113+
self.sim_reg.allocate_qubit(i)
114+
115+
return tuple(range(start, start + n_qubits))
116+
117+
118+
@dataclass
119+
class PyQrackInterpreter(Interpreter):
120+
keys = ["pyqrack", "main"]
121+
memory: MemoryABC = field(kw_only=True)
122+
rng_state: np.random.Generator = field(
123+
default_factory=np.random.default_rng, kw_only=True
124+
)
125+
loss_m_result: Measurement = field(default=Measurement.One, kw_only=True)
126+
"""The value of a measurement result when a qubit is lost."""
127+
128+
def initialize(self) -> Self:
129+
super().initialize()
130+
self.memory.reset() # reset allocated qubits
131+
return self

src/bloqade/pyqrack/noise/__init__.py

Whitespace-only changes.
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from typing import TYPE_CHECKING, List
2+
3+
from kirin import interp
4+
5+
from bloqade.noise import native
6+
from bloqade.pyqrack import PyQrackInterpreter, reg
7+
8+
if TYPE_CHECKING:
9+
from pyqrack import QrackSimulator
10+
11+
12+
@native.dialect.register(key="pyqrack")
13+
class PyQrackMethods(interp.MethodTable):
14+
def apply_pauli_error(
15+
self,
16+
interp: PyQrackInterpreter,
17+
qarg: reg.PyQrackQubit,
18+
px: float,
19+
py: float,
20+
pz: float,
21+
):
22+
p = [1 - (px + py + pz), px, py, pz]
23+
24+
assert all(0 <= x <= 1 for x in p), "Invalid Pauli error probabilities"
25+
26+
which = interp.rng_state.choice(["i", "x", "y", "z"], p=p)
27+
28+
if which == "i":
29+
return
30+
31+
getattr(qarg.sim_reg, which)(qarg.addr)
32+
33+
@interp.impl(native.PauliChannel)
34+
def single_qubit_error_channel(
35+
self,
36+
interp: PyQrackInterpreter,
37+
frame: interp.Frame,
38+
stmt: native.PauliChannel,
39+
):
40+
qargs: List[reg.PyQrackQubit] = frame.get(stmt.qargs)
41+
42+
active_qubits = (qarg for qarg in qargs if qarg.is_active())
43+
44+
for qarg in active_qubits:
45+
self.apply_pauli_error(interp, qarg, stmt.px, stmt.py, stmt.pz)
46+
47+
return ()
48+
49+
@interp.impl(native.CZPauliChannel)
50+
def cz_pauli_channel(
51+
self,
52+
interp: PyQrackInterpreter,
53+
frame: interp.Frame,
54+
stmt: native.CZPauliChannel,
55+
):
56+
57+
qargs: List[reg.PyQrackQubit] = frame.get(stmt.qargs)
58+
ctrls: List[reg.PyQrackQubit] = frame.get(stmt.ctrls)
59+
60+
if stmt.paired:
61+
valid_pairs = (
62+
(ctrl, qarg)
63+
for ctrl, qarg in zip(ctrls, qargs)
64+
if ctrl.is_active() and qarg.is_active()
65+
)
66+
else:
67+
valid_pairs = (
68+
(ctrl, qarg)
69+
for ctrl, qarg in zip(ctrls, qargs)
70+
if ctrl.is_active() ^ qarg.is_active()
71+
)
72+
73+
for ctrl, qarg in valid_pairs:
74+
if ctrl.is_active():
75+
self.apply_pauli_error(
76+
interp, ctrl, stmt.px_ctrl, stmt.py_ctrl, stmt.pz_ctrl
77+
)
78+
79+
if qarg.is_active():
80+
self.apply_pauli_error(
81+
interp, qarg, stmt.px_qarg, stmt.py_qarg, stmt.pz_qarg
82+
)
83+
84+
return ()
85+
86+
@interp.impl(native.AtomLossChannel)
87+
def atom_loss_channel(
88+
self,
89+
interp: PyQrackInterpreter,
90+
frame: interp.Frame,
91+
stmt: native.AtomLossChannel,
92+
):
93+
qargs: List[reg.PyQrackQubit["QrackSimulator"]] = frame.get(stmt.qargs)
94+
95+
active_qubits = (qarg for qarg in qargs if qarg.is_active())
96+
97+
for qarg in active_qubits:
98+
if interp.rng_state.uniform() <= stmt.prob:
99+
sim_reg = qarg.ref.sim_reg
100+
sim_reg.force_m(qarg.addr, 0)
101+
qarg.drop()
102+
103+
return ()

src/bloqade/pyqrack/qasm2/__init__.py

Whitespace-only changes.

src/bloqade/pyqrack/qasm2/core.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from kirin import interp
2+
3+
from bloqade.pyqrack.reg import (
4+
CBitRef,
5+
CRegister,
6+
PyQrackReg,
7+
QubitState,
8+
Measurement,
9+
PyQrackQubit,
10+
)
11+
from bloqade.pyqrack.base import PyQrackInterpreter
12+
from bloqade.qasm2.dialects import core
13+
14+
15+
@core.dialect.register(key="pyqrack")
16+
class PyQrackMethods(interp.MethodTable):
17+
18+
@interp.impl(core.QRegNew)
19+
def qreg_new(
20+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegNew
21+
):
22+
n_qubits: int = frame.get(stmt.n_qubits)
23+
return (
24+
PyQrackReg(
25+
size=n_qubits,
26+
sim_reg=interp.memory.sim_reg,
27+
addrs=interp.memory.allocate(n_qubits),
28+
qubit_state=[QubitState.Active] * n_qubits,
29+
),
30+
)
31+
32+
@interp.impl(core.CRegNew)
33+
def creg_new(
34+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.CRegNew
35+
):
36+
n_bits: int = frame.get(stmt.n_bits)
37+
return (CRegister(size=n_bits),)
38+
39+
@interp.impl(core.QRegGet)
40+
def qreg_get(
41+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.QRegGet
42+
):
43+
return (PyQrackQubit(ref=frame.get(stmt.reg), pos=frame.get(stmt.idx)),)
44+
45+
@interp.impl(core.CRegGet)
46+
def creg_get(
47+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.CRegGet
48+
):
49+
return (CBitRef(ref=frame.get(stmt.reg), pos=frame.get(stmt.idx)),)
50+
51+
@interp.impl(core.Measure)
52+
def measure(
53+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.Measure
54+
):
55+
qarg: PyQrackQubit = frame.get(stmt.qarg)
56+
carg: CBitRef = frame.get(stmt.carg)
57+
if qarg.is_active():
58+
carg.set_value(Measurement(qarg.sim_reg.m(qarg.addr)))
59+
else:
60+
carg.set_value(interp.loss_m_result)
61+
62+
return ()
63+
64+
@interp.impl(core.Reset)
65+
def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.Reset):
66+
qarg: PyQrackQubit = frame.get(stmt.qarg)
67+
qarg.sim_reg.force_m(qarg.addr, 0)
68+
return ()
69+
70+
@interp.impl(core.CRegEq)
71+
def creg_eq(
72+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: core.CRegEq
73+
):
74+
lhs: CRegister = frame.get(stmt.lhs)
75+
rhs: CRegister = frame.get(stmt.rhs)
76+
if len(lhs) != len(rhs):
77+
return (False,)
78+
79+
return (all(left is right for left, right in zip(lhs, rhs)),)

0 commit comments

Comments
 (0)