Skip to content

Commit 3eccad7

Browse files
committed
Rework impl of control to work with all supported operators
1 parent 418208b commit 3eccad7

File tree

2 files changed

+96
-30
lines changed

2 files changed

+96
-30
lines changed

src/bloqade/pyqrack/squin/op.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
KronRuntime,
88
MultRuntime,
99
ScaleRuntime,
10+
AdjointRuntime,
1011
ControlRuntime,
1112
IdentityRuntime,
1213
OperatorRuntime,
@@ -33,13 +34,12 @@ def mult(
3334
rhs = frame.get(stmt.rhs)
3435
return (MultRuntime(lhs, rhs),)
3536

36-
# @interp.impl(op.stmts.Adjoint)
37-
# def adjoint(
38-
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Adjoint
39-
# ):
40-
# is_unitary: bool = info.attribute(default=False)
41-
# op: ir.SSAValue = info.argument(OpType)
42-
# result: ir.ResultValue = info.result(OpType)
37+
@interp.impl(op.stmts.Adjoint)
38+
def adjoint(
39+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Adjoint
40+
):
41+
op = frame.get(stmt.op)
42+
return (AdjointRuntime(op),)
4343

4444
@interp.impl(op.stmts.Scale)
4545
def scale(
@@ -55,9 +55,8 @@ def control(
5555
):
5656
op = frame.get(stmt.op)
5757
n_controls = stmt.n_controls
58-
# FIXME: the method name here is dirty
5958
rt = ControlRuntime(
60-
method_name="mc" + op.method_name,
59+
op=op,
6160
n_controls=n_controls,
6261
)
6362
return (rt,)
Lines changed: 88 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,76 @@
11
from dataclasses import dataclass
22

3+
import numpy as np
4+
35
from bloqade.pyqrack import PyQrackQubit
46

57

68
@dataclass
79
class OperatorRuntimeABC:
8-
def apply(self, *qubits: PyQrackQubit) -> None:
10+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
911
raise NotImplementedError(
1012
"Operator runtime base class should not be called directly, override the method"
1113
)
1214

15+
def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
16+
raise NotImplementedError(f"Can't apply controlled version of {self}")
17+
1318

1419
@dataclass
1520
class OperatorRuntime(OperatorRuntimeABC):
1621
method_name: str
1722

18-
def apply(self, *qubits: PyQrackQubit) -> None:
19-
getattr(qubits[0].sim_reg, self.method_name)(qubits[0].addr)
23+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
24+
method_name = self.method_name
25+
if adjoint:
26+
method_name = "adj" + method_name
27+
getattr(qubits[0].sim_reg, method_name)(qubits[0].addr)
28+
29+
def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
30+
ctrls = [qbit.addr for qbit in qubits[:-1]]
31+
target = qubits[-1]
32+
method_name = "mc"
33+
if adjoint:
34+
method_name += "adj"
35+
method_name += self.method_name
36+
getattr(target.sim_reg, method_name)(target.addr, ctrls)
2037

2138

2239
@dataclass
2340
class ControlRuntime(OperatorRuntimeABC):
24-
method_name: str
41+
op: OperatorRuntimeABC
2542
n_controls: int
2643

27-
def apply(self, *qubits: PyQrackQubit) -> None:
44+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
2845
# NOTE: this is a bit odd, since you can "skip" qubits by making n_controls < len(qubits)
29-
ctrls = [qbit.addr for qbit in qubits[: self.n_controls]]
46+
ctrls = qubits[: self.n_controls]
3047
target = qubits[-1]
31-
getattr(target.sim_reg, self.method_name)(ctrls, target.addr)
48+
self.op.control_apply(target, *ctrls, adjoint=adjoint)
3249

3350

3451
@dataclass
3552
class ProjectorRuntime(OperatorRuntimeABC):
3653
to_state: bool
3754

38-
def apply(self, *qubits: PyQrackQubit) -> None:
39-
qubits[-1].sim_reg.force_m(qubits[-1].addr, self.to_state)
55+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
56+
qubits[0].sim_reg.force_m(qubits[0].addr, self.to_state)
57+
58+
def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
59+
m = [not self.to_state, 0, 0, self.to_state]
60+
target = qubits[-1]
61+
ctrls = [qbit.addr for qbit in qubits[:-1]]
62+
target.sim_reg.mcmtrx(ctrls, m, target.addr)
4063

4164

4265
@dataclass
4366
class IdentityRuntime(OperatorRuntimeABC):
4467
# TODO: do we even need sites? The apply never does anything
4568
sites: int
4669

47-
def apply(self, *qubits: PyQrackQubit) -> None:
70+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
71+
pass
72+
73+
def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
4874
pass
4975

5076

@@ -53,32 +79,73 @@ class MultRuntime(OperatorRuntimeABC):
5379
lhs: OperatorRuntimeABC
5480
rhs: OperatorRuntimeABC
5581

56-
def apply(self, *qubits: PyQrackQubit) -> None:
57-
self.rhs.apply(*qubits)
58-
self.lhs.apply(*qubits)
82+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
83+
if adjoint:
84+
# NOTE: inverted order
85+
self.lhs.apply(*qubits, adjoint=adjoint)
86+
self.rhs.apply(*qubits, adjoint=adjoint)
87+
else:
88+
self.rhs.apply(*qubits)
89+
self.lhs.apply(*qubits)
90+
91+
def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
92+
if adjoint:
93+
self.lhs.control_apply(*qubits, adjoint=adjoint)
94+
self.rhs.control_apply(*qubits, adjoint=adjoint)
95+
else:
96+
self.rhs.control_apply(*qubits, adjoint=adjoint)
97+
self.lhs.control_apply(*qubits, adjoint=adjoint)
5998

6099

61100
@dataclass
62101
class KronRuntime(OperatorRuntimeABC):
63102
lhs: OperatorRuntimeABC
64103
rhs: OperatorRuntimeABC
65104

66-
def apply(self, *qubits: PyQrackQubit) -> None:
67-
self.lhs.apply(qubits[0])
68-
self.rhs.apply(qubits[1])
105+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
106+
self.lhs.apply(qubits[0], adjoint=adjoint)
107+
self.rhs.apply(qubits[1], adjoint=adjoint)
69108

70109

71110
@dataclass
72111
class ScaleRuntime(OperatorRuntimeABC):
73112
op: OperatorRuntimeABC
74113
factor: complex
75114

76-
def apply(self, *qubits: PyQrackQubit) -> None:
77-
target = qubits[0]
78-
self.op.apply(target)
115+
def mat(self, adjoint: bool):
116+
if adjoint:
117+
return [np.conj(self.factor), 0, 0, self.factor]
118+
else:
119+
return [self.factor, 0, 0, self.factor]
120+
121+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
122+
self.op.apply(*qubits, adjoint=adjoint)
123+
124+
target = qubits[-1]
79125

80126
# NOTE: just factor * eye(2)
81-
mat = [self.factor, 0, 0, self.factor]
127+
m = self.mat(adjoint)
82128

83129
# TODO: output seems to always be normalized -- no-op?
84-
target.sim_reg.mtrx(mat, target.addr)
130+
target.sim_reg.mtrx(m, target.addr)
131+
132+
def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
133+
self.op.control_apply(*qubits, adjoint=adjoint)
134+
135+
target = qubits[-1]
136+
ctrls = [qbit.addr for qbit in qubits[:-1]]
137+
138+
m = self.mat(adjoint=adjoint)
139+
140+
target.sim_reg.mcmtrx(ctrls, m, target.addr)
141+
142+
143+
@dataclass
144+
class AdjointRuntime(OperatorRuntimeABC):
145+
op: OperatorRuntimeABC
146+
147+
def apply(self, *qubits: PyQrackQubit, adjoint: bool = True) -> None:
148+
self.op.apply(*qubits, adjoint=adjoint)
149+
150+
def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = True) -> None:
151+
self.op.control_apply(*qubits, adjoint=adjoint)

0 commit comments

Comments
 (0)