Skip to content

Commit cdf3280

Browse files
david-plweinbe58
andauthored
Implement pyqrack interpreter methods for squin dialect (#207)
Closes #185 --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent b310f03 commit cdf3280

File tree

13 files changed

+1412
-15
lines changed

13 files changed

+1412
-15
lines changed

src/bloqade/pyqrack/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
CRegister as CRegister,
44
QubitState as QubitState,
55
Measurement as Measurement,
6+
PyQrackWire as PyQrackWire,
67
PyQrackQubit as PyQrackQubit,
78
)
89
from .base import (
@@ -14,4 +15,5 @@
1415
# NOTE: The following import is for registering the method tables
1516
from .noise import native as native
1617
from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
18+
from .squin import op as op, qubit as qubit
1719
from .target import PyQrack as PyQrack

src/bloqade/pyqrack/reg.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,8 @@ def is_active(self) -> bool:
7070
def drop(self):
7171
"""Drop the qubit in-place."""
7272
self.state = QubitState.Lost
73+
74+
75+
@dataclass
76+
class PyQrackWire:
77+
qubit: PyQrackQubit

src/bloqade/pyqrack/squin/__init__.py

Whitespace-only changes.

src/bloqade/pyqrack/squin/op.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from kirin import interp
2+
3+
from bloqade.squin import op
4+
from bloqade.pyqrack.base import PyQrackInterpreter
5+
6+
from .runtime import (
7+
SnRuntime,
8+
SpRuntime,
9+
U3Runtime,
10+
RotRuntime,
11+
KronRuntime,
12+
MultRuntime,
13+
ScaleRuntime,
14+
AdjointRuntime,
15+
ControlRuntime,
16+
PhaseOpRuntime,
17+
IdentityRuntime,
18+
OperatorRuntime,
19+
ProjectorRuntime,
20+
OperatorRuntimeABC,
21+
PauliStringRuntime,
22+
)
23+
24+
25+
@op.dialect.register(key="pyqrack")
26+
class PyQrackMethods(interp.MethodTable):
27+
28+
@interp.impl(op.stmts.Kron)
29+
def kron(
30+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Kron
31+
) -> tuple[OperatorRuntimeABC]:
32+
lhs = frame.get(stmt.lhs)
33+
rhs = frame.get(stmt.rhs)
34+
return (KronRuntime(lhs, rhs),)
35+
36+
@interp.impl(op.stmts.Mult)
37+
def mult(
38+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Mult
39+
) -> tuple[OperatorRuntimeABC]:
40+
lhs = frame.get(stmt.lhs)
41+
rhs = frame.get(stmt.rhs)
42+
return (MultRuntime(lhs, rhs),)
43+
44+
@interp.impl(op.stmts.Adjoint)
45+
def adjoint(
46+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Adjoint
47+
) -> tuple[OperatorRuntimeABC]:
48+
op = frame.get(stmt.op)
49+
return (AdjointRuntime(op),)
50+
51+
@interp.impl(op.stmts.Scale)
52+
def scale(
53+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Scale
54+
) -> tuple[OperatorRuntimeABC]:
55+
op = frame.get(stmt.op)
56+
factor = frame.get(stmt.factor)
57+
return (ScaleRuntime(op, factor),)
58+
59+
@interp.impl(op.stmts.Control)
60+
def control(
61+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Control
62+
) -> tuple[OperatorRuntimeABC]:
63+
op = frame.get(stmt.op)
64+
n_controls = stmt.n_controls
65+
rt = ControlRuntime(
66+
op=op,
67+
n_controls=n_controls,
68+
)
69+
return (rt,)
70+
71+
@interp.impl(op.stmts.Rot)
72+
def rot(
73+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Rot
74+
) -> tuple[OperatorRuntimeABC]:
75+
axis = frame.get(stmt.axis)
76+
angle = frame.get(stmt.angle)
77+
return (RotRuntime(axis, angle),)
78+
79+
@interp.impl(op.stmts.Identity)
80+
def identity(
81+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Identity
82+
) -> tuple[OperatorRuntimeABC]:
83+
return (IdentityRuntime(sites=stmt.sites),)
84+
85+
@interp.impl(op.stmts.PhaseOp)
86+
@interp.impl(op.stmts.ShiftOp)
87+
def phaseop(
88+
self,
89+
interp: PyQrackInterpreter,
90+
frame: interp.Frame,
91+
stmt: op.stmts.PhaseOp | op.stmts.ShiftOp,
92+
) -> tuple[OperatorRuntimeABC]:
93+
theta = frame.get(stmt.theta)
94+
global_ = isinstance(stmt, op.stmts.PhaseOp)
95+
return (PhaseOpRuntime(theta, global_=global_),)
96+
97+
@interp.impl(op.stmts.X)
98+
@interp.impl(op.stmts.Y)
99+
@interp.impl(op.stmts.Z)
100+
@interp.impl(op.stmts.H)
101+
@interp.impl(op.stmts.S)
102+
@interp.impl(op.stmts.T)
103+
def operator(
104+
self,
105+
interp: PyQrackInterpreter,
106+
frame: interp.Frame,
107+
stmt: (
108+
op.stmts.X | op.stmts.Y | op.stmts.Z | op.stmts.H | op.stmts.S | op.stmts.T
109+
),
110+
) -> tuple[OperatorRuntimeABC]:
111+
return (OperatorRuntime(method_name=stmt.name.lower()),)
112+
113+
@interp.impl(op.stmts.P0)
114+
@interp.impl(op.stmts.P1)
115+
def projector(
116+
self,
117+
interp: PyQrackInterpreter,
118+
frame: interp.Frame,
119+
stmt: op.stmts.P0 | op.stmts.P1,
120+
) -> tuple[OperatorRuntimeABC]:
121+
state = isinstance(stmt, op.stmts.P1)
122+
return (ProjectorRuntime(to_state=state),)
123+
124+
@interp.impl(op.stmts.Sp)
125+
def sp(
126+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sp
127+
) -> tuple[OperatorRuntimeABC]:
128+
return (SpRuntime(),)
129+
130+
@interp.impl(op.stmts.Sn)
131+
def sn(
132+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sn
133+
) -> tuple[OperatorRuntimeABC]:
134+
return (SnRuntime(),)
135+
136+
@interp.impl(op.stmts.U3)
137+
def u3(
138+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.U3
139+
) -> tuple[OperatorRuntimeABC]:
140+
theta = frame.get(stmt.theta)
141+
phi = frame.get(stmt.phi)
142+
lam = frame.get(stmt.lam)
143+
return (U3Runtime(theta, phi, lam),)
144+
145+
@interp.impl(op.stmts.PauliString)
146+
def clifford_string(
147+
self,
148+
interp: PyQrackInterpreter,
149+
frame: interp.Frame,
150+
stmt: op.stmts.PauliString,
151+
) -> tuple[OperatorRuntimeABC]:
152+
string = stmt.string
153+
ops = [OperatorRuntime(method_name=name.lower()) for name in stmt.string]
154+
return (PauliStringRuntime(string, ops),)

src/bloqade/pyqrack/squin/qubit.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import Any
2+
3+
from kirin import interp
4+
from kirin.dialects import ilist
5+
6+
from bloqade.squin import qubit
7+
from bloqade.pyqrack.reg import QubitState, PyQrackQubit
8+
from bloqade.pyqrack.base import PyQrackInterpreter
9+
10+
from .runtime import OperatorRuntimeABC
11+
12+
13+
@qubit.dialect.register(key="pyqrack")
14+
class PyQrackMethods(interp.MethodTable):
15+
@interp.impl(qubit.New)
16+
def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New):
17+
n_qubits: int = frame.get(stmt.n_qubits)
18+
qreg = ilist.IList(
19+
[
20+
PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active)
21+
for i in interp.memory.allocate(n_qubits=n_qubits)
22+
]
23+
)
24+
return (qreg,)
25+
26+
@interp.impl(qubit.Apply)
27+
def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Apply):
28+
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
29+
operator: OperatorRuntimeABC = frame.get(stmt.operator)
30+
operator.apply(*qubits)
31+
32+
@interp.impl(qubit.Broadcast)
33+
def broadcast(
34+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Broadcast
35+
):
36+
operator: OperatorRuntimeABC = frame.get(stmt.operator)
37+
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
38+
operator.broadcast_apply(qubits)
39+
40+
def _measure_qubit(self, qbit: PyQrackQubit):
41+
if qbit.is_active():
42+
return bool(qbit.sim_reg.m(qbit.addr))
43+
44+
@interp.impl(qubit.MeasureQubitList)
45+
def measure_qubit_list(
46+
self,
47+
interp: PyQrackInterpreter,
48+
frame: interp.Frame,
49+
stmt: qubit.MeasureQubitList,
50+
):
51+
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
52+
result = ilist.IList([self._measure_qubit(qbit) for qbit in qubits])
53+
return (result,)
54+
55+
@interp.impl(qubit.MeasureQubit)
56+
def measure_qubit(
57+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasureQubit
58+
):
59+
qbit: PyQrackQubit = frame.get(stmt.qubit)
60+
result = self._measure_qubit(qbit)
61+
return (result,)
62+
63+
@interp.impl(qubit.MeasureAndReset)
64+
def measure_and_reset(
65+
self,
66+
interp: PyQrackInterpreter,
67+
frame: interp.Frame,
68+
stmt: qubit.MeasureAndReset,
69+
):
70+
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
71+
result = []
72+
for qbit in qubits:
73+
if qbit.is_active():
74+
result.append(qbit.sim_reg.m(qbit.addr))
75+
else:
76+
result.append(None)
77+
qbit.sim_reg.force_m(qbit.addr, 0)
78+
79+
return (ilist.IList(result),)
80+
81+
@interp.impl(qubit.Reset)
82+
def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Reset):
83+
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
84+
for qbit in qubits:
85+
qbit.sim_reg.force_m(qbit.addr, 0)

0 commit comments

Comments
 (0)