Skip to content

Commit 05c8944

Browse files
committed
Implement Sp/Sn runtime
1 parent e55a54e commit 05c8944

File tree

3 files changed

+73
-16
lines changed

3 files changed

+73
-16
lines changed

src/bloqade/pyqrack/squin/op.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from bloqade.pyqrack.base import PyQrackInterpreter
55

66
from .runtime import (
7+
SnRuntime,
8+
SpRuntime,
79
RotRuntime,
810
KronRuntime,
911
MultRuntime,
@@ -114,10 +116,10 @@ def projector(
114116
state = isinstance(stmt, op.stmts.P1)
115117
return (ProjectorRuntime(to_state=state),)
116118

117-
@interp.impl(op.stmts.Sn)
118-
def sn(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sn):
119-
raise NotImplementedError()
120-
121119
@interp.impl(op.stmts.Sp)
122120
def sp(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sp):
123-
raise NotImplementedError()
121+
return (SpRuntime(),)
122+
123+
@interp.impl(op.stmts.Sn)
124+
def sn(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sn):
125+
return (SnRuntime(),)

src/bloqade/pyqrack/squin/runtime.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,14 @@ def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
141141

142142

143143
@dataclass(frozen=True)
144-
class PhaseOpRuntime(OperatorRuntimeABC):
145-
theta: float
146-
global_: bool
147-
148-
def mat(self, adjoint: bool):
149-
sign = (-1) ** (not adjoint)
150-
phase = np.exp(sign * 1j * self.theta)
151-
return [self.global_ * phase, 0, 0, phase]
144+
class MtrxOpRuntime(OperatorRuntimeABC):
145+
def mat(self, adjoint: bool) -> list[complex]:
146+
raise NotImplementedError("Override this method in the subclass!")
152147

153148
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
154149
target = qubits[-1]
155-
target.sim_reg.mtrx(self.mat(adjoint=adjoint), target.addr)
150+
m = self.mat(adjoint=adjoint)
151+
target.sim_reg.mtrx(m, target.addr)
156152

157153
def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
158154
target = qubits[-1]
@@ -163,6 +159,35 @@ def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
163159
target.sim_reg.mcmtrx(ctrls, m, target.addr)
164160

165161

162+
@dataclass(frozen=True)
163+
class SpRuntime(MtrxOpRuntime):
164+
def mat(self, adjoint: bool) -> list[complex]:
165+
if adjoint:
166+
return [0, 0, 1, 0]
167+
else:
168+
return [0, 1, 0, 0]
169+
170+
171+
@dataclass(frozen=True)
172+
class SnRuntime(MtrxOpRuntime):
173+
def mat(self, adjoint: bool) -> list[complex]:
174+
if adjoint:
175+
return [0, 1, 0, 0]
176+
else:
177+
return [0, 0, 1, 0]
178+
179+
180+
@dataclass(frozen=True)
181+
class PhaseOpRuntime(MtrxOpRuntime):
182+
theta: float
183+
global_: bool
184+
185+
def mat(self, adjoint: bool) -> list[complex]:
186+
sign = (-1) ** (not adjoint)
187+
phase = np.exp(sign * 1j * self.theta)
188+
return [self.global_ * phase, 0, 0, phase]
189+
190+
166191
@dataclass(frozen=True)
167192
class RotRuntime(OperatorRuntimeABC):
168193
axis: OperatorRuntimeABC

test/pyqrack/test_squin.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,42 @@ def main():
170170
assert result == [0]
171171

172172

173+
def test_sp():
174+
@squin.kernel
175+
def main():
176+
q = squin.qubit.new(1)
177+
sp = squin.op.spin_p()
178+
squin.qubit.apply(sp, q)
179+
return q
180+
181+
target = PyQrack(1)
182+
result = target.run(main)
183+
assert isinstance(result, ilist.IList)
184+
assert isinstance(qubit := result[0], PyQrackQubit)
185+
186+
assert qubit.sim_reg.out_ket() == [0, 0]
187+
188+
@squin.kernel
189+
def main2():
190+
q = squin.qubit.new(1)
191+
sn = squin.op.spin_n()
192+
sp = squin.op.spin_p()
193+
squin.qubit.apply(sn, q)
194+
squin.qubit.apply(sp, q)
195+
return squin.qubit.measure(q)
196+
197+
target = PyQrack(1)
198+
result = target.run(main2)
199+
assert result == [0]
200+
201+
173202
# TODO: remove
174203
# test_qubit()
175204
# test_x()
176205
# test_basic_ops("x")
177206
# test_cx()
178207
# test_mult()
179208
# test_kron()
180-
test_scale()
181-
test_phase()
209+
# test_scale()
210+
# test_phase()
211+
test_sp()

0 commit comments

Comments
 (0)