Skip to content

Commit a328e07

Browse files
committed
Fix method name for adjoints
1 parent 9750e06 commit a328e07

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

src/bloqade/pyqrack/squin/runtime.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,24 @@ def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
2020
class OperatorRuntime(OperatorRuntimeABC):
2121
method_name: str
2222

23+
def get_method_name(self, adjoint: bool, control: bool) -> str:
24+
method_name = ""
25+
if control:
26+
method_name += "mc"
27+
28+
if adjoint and self.method_name in ("s", "t"):
29+
method_name += "adj"
30+
31+
return method_name + self.method_name
32+
2333
def apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
24-
method_name = self.method_name
25-
if adjoint:
26-
method_name = "adj" + method_name
34+
method_name = self.get_method_name(adjoint=adjoint, control=False)
2735
getattr(qubits[0].sim_reg, method_name)(qubits[0].addr)
2836

2937
def control_apply(self, *qubits: PyQrackQubit, adjoint: bool = False) -> None:
3038
ctrls = [qbit.addr for qbit in qubits[:-1]]
3139
target = qubits[-1]
32-
method_name = "mc"
33-
if adjoint:
34-
method_name += "adj"
35-
method_name += self.method_name
40+
method_name = self.get_method_name(adjoint=adjoint, control=True)
3641
getattr(target.sim_reg, method_name)(target.addr, ctrls)
3742

3843

test/pyqrack/test_squin.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,20 @@ def main2():
199199
assert result == [0]
200200

201201

202+
def test_adjoint():
203+
@squin.kernel
204+
def main():
205+
q = squin.qubit.new(1)
206+
x = squin.op.x()
207+
xadj = squin.op.adjoint(x)
208+
squin.qubit.apply(xadj, q)
209+
return squin.qubit.measure(q)
210+
211+
target = PyQrack(1)
212+
result = target.run(main)
213+
assert result == [1]
214+
215+
202216
# TODO: remove
203217
# test_qubit()
204218
# test_x()
@@ -208,4 +222,5 @@ def main2():
208222
# test_kron()
209223
# test_scale()
210224
# test_phase()
211-
test_sp()
225+
# test_sp()
226+
# test_adjoint()

0 commit comments

Comments
 (0)