Skip to content

Commit 0d8fae8

Browse files
committed
Runtime for phase and shift operators
1 parent 3eccad7 commit 0d8fae8

File tree

3 files changed

+79
-44
lines changed

3 files changed

+79
-44
lines changed

src/bloqade/pyqrack/squin/op.py

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
from bloqade.pyqrack.base import PyQrackInterpreter
55

66
from .runtime import (
7+
RotRuntime,
78
KronRuntime,
89
MultRuntime,
910
ScaleRuntime,
1011
AdjointRuntime,
1112
ControlRuntime,
13+
PhaseOpRuntime,
1214
IdentityRuntime,
1315
OperatorRuntime,
1416
ProjectorRuntime,
@@ -61,47 +63,29 @@ def control(
6163
)
6264
return (rt,)
6365

64-
# @interp.impl(op.stmts.Rot)
65-
# def rot(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Rot):
66-
# axis: ir.SSAValue = info.argument(OpType)
67-
# angle: ir.SSAValue = info.argument(types.Float)
68-
# result: ir.ResultValue = info.result(OpType)
66+
@interp.impl(op.stmts.Rot)
67+
def rot(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Rot):
68+
axis = frame.get(stmt.axis)
69+
angle = frame.get(stmt.angle)
70+
return (RotRuntime(axis, angle),)
6971

7072
@interp.impl(op.stmts.Identity)
7173
def identity(
7274
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Identity
7375
):
7476
return (IdentityRuntime(sites=stmt.sites),)
7577

76-
# @interp.impl(op.stmts.PhaseOp)
77-
# def phaseop(
78-
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.PhaseOp
79-
# ):
80-
# """
81-
# A phase operator.
82-
83-
# $$
84-
# PhaseOp(theta) = e^{i \theta} I
85-
# $$
86-
# """
87-
88-
# theta: ir.SSAValue = info.argument(types.Float)
89-
# result: ir.ResultValue = info.result(OpType)
90-
91-
# @interp.impl(op.stmts.ShiftOp)
92-
# def shiftop(
93-
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.ShiftOp
94-
# ):
95-
# """
96-
# A phase shift operator.
97-
98-
# $$
99-
# Shift(theta) = \\begin{bmatrix} 1 & 0 \\\\ 0 & e^{i \\theta} \\end{bmatrix}
100-
# $$
101-
# """
102-
103-
# theta: ir.SSAValue = info.argument(types.Float)
104-
# result: ir.ResultValue = info.result(OpType)
78+
@interp.impl(op.stmts.PhaseOp)
79+
@interp.impl(op.stmts.ShiftOp)
80+
def phaseop(
81+
self,
82+
interp: PyQrackInterpreter,
83+
frame: interp.Frame,
84+
stmt: op.stmts.PhaseOp | op.stmts.ShiftOp,
85+
):
86+
theta = frame.get(stmt.theta)
87+
global_ = isinstance(stmt, op.stmts.PhaseOp)
88+
return (PhaseOpRuntime(theta, global_=global_),)
10589

10690
@interp.impl(op.stmts.X)
10791
@interp.impl(op.stmts.Y)

src/bloqade/pyqrack/squin/runtime.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from bloqade.pyqrack import PyQrackQubit
66

77

8-
@dataclass
8+
@dataclass(frozen=True)
99
class OperatorRuntimeABC:
1010
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
1111
raise NotImplementedError(
@@ -16,7 +16,7 @@ def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
1616
raise NotImplementedError(f"Can't apply controlled version of {self}")
1717

1818

19-
@dataclass
19+
@dataclass(frozen=True)
2020
class OperatorRuntime(OperatorRuntimeABC):
2121
method_name: str
2222

@@ -36,7 +36,7 @@ def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
3636
getattr(target.sim_reg, method_name)(target.addr, ctrls)
3737

3838

39-
@dataclass
39+
@dataclass(frozen=True)
4040
class ControlRuntime(OperatorRuntimeABC):
4141
op: OperatorRuntimeABC
4242
n_controls: int
@@ -48,7 +48,7 @@ def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
4848
self.op.control_apply(target, *ctrls, adjoint=adjoint)
4949

5050

51-
@dataclass
51+
@dataclass(frozen=True)
5252
class ProjectorRuntime(OperatorRuntimeABC):
5353
to_state: bool
5454

@@ -62,7 +62,7 @@ def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
6262
target.sim_reg.mcmtrx(ctrls, m, target.addr)
6363

6464

65-
@dataclass
65+
@dataclass(frozen=True)
6666
class IdentityRuntime(OperatorRuntimeABC):
6767
# TODO: do we even need sites? The apply never does anything
6868
sites: int
@@ -74,7 +74,7 @@ def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
7474
pass
7575

7676

77-
@dataclass
77+
@dataclass(frozen=True)
7878
class MultRuntime(OperatorRuntimeABC):
7979
lhs: OperatorRuntimeABC
8080
rhs: OperatorRuntimeABC
@@ -97,7 +97,7 @@ def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
9797
self.lhs.control_apply(*qubits, adjoint=adjoint)
9898

9999

100-
@dataclass
100+
@dataclass(frozen=True)
101101
class KronRuntime(OperatorRuntimeABC):
102102
lhs: OperatorRuntimeABC
103103
rhs: OperatorRuntimeABC
@@ -107,7 +107,7 @@ def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
107107
self.rhs.apply(qubits[1], adjoint=adjoint)
108108

109109

110-
@dataclass
110+
@dataclass(frozen=True)
111111
class ScaleRuntime(OperatorRuntimeABC):
112112
op: OperatorRuntimeABC
113113
factor: complex
@@ -140,7 +140,37 @@ def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
140140
target.sim_reg.mcmtrx(ctrls, m, target.addr)
141141

142142

143-
@dataclass
143+
@dataclass(frozen=True)
144+
class PhaseOpRuntime(OperatorRuntimeABC):
145+
theta: float
146+
global_: bool
147+
148+
def mat(self, adjoint: bool):
149+
sign = (-1) ** (not adjoint)
150+
phase = np.exp(sign * 1j * self.theta)
151+
return [self.global_ * phase, 0, 0, phase]
152+
153+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
154+
target = qubits[-1]
155+
target.sim_reg.mtrx(self.mat(adjoint=adjoint), target.addr)
156+
157+
def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
158+
target = qubits[-1]
159+
ctrls = [qbit.addr for qbit in qubits[:-1]]
160+
161+
m = self.mat(adjoint=adjoint)
162+
163+
target.sim_reg.mcmtrx(ctrls, m, target.addr)
164+
165+
166+
@dataclass(frozen=True)
167+
class RotRuntime(OperatorRuntimeABC):
168+
axis: OperatorRuntimeABC
169+
angle: float
170+
# TODO: how does this work?
171+
172+
173+
@dataclass(frozen=True)
144174
class AdjointRuntime(OperatorRuntimeABC):
145175
op: OperatorRuntimeABC
146176

test/pyqrack/test_squin.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,32 @@ def main():
150150
assert result == [1]
151151

152152

153+
def test_phase():
154+
@squin.kernel
155+
def main():
156+
q = squin.qubit.new(1)
157+
h = squin.op.h()
158+
squin.qubit.apply(h, q)
159+
160+
# rotate local phase by pi/2
161+
p = squin.op.shift(math.pi / 2)
162+
squin.qubit.apply(p, q)
163+
164+
# the next hadamard should rotate it back to 0
165+
squin.qubit.apply(h, q)
166+
return squin.qubit.measure(q)
167+
168+
target = PyQrack(1)
169+
result = target.run(main)
170+
assert result == [0]
171+
172+
153173
# TODO: remove
154174
# test_qubit()
155175
# test_x()
156176
# test_basic_ops("x")
157177
# test_cx()
158178
# test_mult()
159179
# test_kron()
160-
# test_scale()
180+
test_scale()
181+
test_phase()

0 commit comments

Comments
 (0)