Skip to content

Commit c01d377

Browse files
committed
Implement mult runtime -- TODO: don't wrap mult
1 parent 300364e commit c01d377

File tree

4 files changed

+54
-22
lines changed

4 files changed

+54
-22
lines changed

src/bloqade/pyqrack/squin/op.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55
# from bloqade.pyqrack.reg import QubitState, PyQrackQubit
66
from bloqade.pyqrack.base import PyQrackInterpreter
77

8-
from .runtime import ControlRuntime, IdentityRuntime, OperatorRuntime, ProjectorRuntime
8+
from .runtime import (
9+
MultRuntime,
10+
ControlRuntime,
11+
IdentityRuntime,
12+
OperatorRuntime,
13+
ProjectorRuntime,
14+
)
915

1016
# from kirin.dialects import ilist
1117

@@ -19,11 +25,13 @@ class PyQrackMethods(interp.MethodTable):
1925
# ):
2026
# is_unitary: bool = info.attribute(default=False)
2127

22-
# @interp.impl(op.stmts.Mult)
23-
# def mult(
24-
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Mult
25-
# ):
26-
# is_unitary: bool = info.attribute(default=False)
28+
@interp.impl(op.stmts.Mult)
29+
def mult(
30+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Mult
31+
):
32+
lhs = frame.get(stmt.lhs)
33+
rhs = frame.get(stmt.rhs)
34+
return (MultRuntime(lhs, rhs),)
2735

2836
# @interp.impl(op.stmts.Adjoint)
2937
# def adjoint(

src/bloqade/pyqrack/squin/runtime.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@ def apply(self, *qubits: PyQrackQubit) -> None:
1515
class OperatorRuntime(OperatorRuntimeABC):
1616
method_name: str
1717

18-
def apply(
19-
self,
20-
*qubits: PyQrackQubit,
21-
) -> None:
18+
def apply(self, *qubits: PyQrackQubit) -> None:
2219
getattr(qubits[-1].sim_reg, self.method_name)(qubits[-1].addr)
2320

2421

@@ -27,10 +24,7 @@ class ControlRuntime(OperatorRuntimeABC):
2724
method_name: str
2825
n_controls: int
2926

30-
def apply(
31-
self,
32-
*qubits: PyQrackQubit,
33-
) -> None:
27+
def apply(self, *qubits: PyQrackQubit) -> None:
3428
# NOTE: this is a bit odd, since you can "skip" qubits by making n_controls < len(qubits)
3529
ctrls = [qbit.addr for qbit in qubits[: self.n_controls]]
3630
target = qubits[-1]
@@ -41,10 +35,7 @@ def apply(
4135
class ProjectorRuntime(OperatorRuntimeABC):
4236
to_state: bool
4337

44-
def apply(
45-
self,
46-
*qubits: PyQrackQubit,
47-
) -> None:
38+
def apply(self, *qubits: PyQrackQubit) -> None:
4839
qubits[-1].sim_reg.force_m(qubits[-1].addr, self.to_state)
4940

5041

@@ -55,3 +46,13 @@ class IdentityRuntime(OperatorRuntimeABC):
5546

5647
def apply(self, *qubits: PyQrackQubit) -> None:
5748
pass
49+
50+
51+
@dataclass
52+
class MultRuntime(OperatorRuntimeABC):
53+
lhs: OperatorRuntimeABC
54+
rhs: OperatorRuntimeABC
55+
56+
def apply(self, *qubits: PyQrackQubit) -> None:
57+
self.rhs.apply(*qubits)
58+
self.lhs.apply(*qubits)

src/bloqade/squin/op/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ...
1212

1313

14+
# FIXME: should we just rewrite the py.binop.mult instead?
15+
@_wraps(stmts.Mult)
16+
def mult(lhs: types.Op, rhs: types.Op) -> types.Op: ...
17+
18+
1419
@_wraps(stmts.Adjoint)
1520
def adjoint(op: types.Op) -> types.Op: ...
1621

test/pyqrack/test_squin.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,26 @@ def main():
101101
target.run(main)
102102

103103

104+
def test_mult():
105+
@squin.kernel
106+
def main():
107+
q = squin.qubit.new(1)
108+
x = squin.op.x()
109+
id = squin.op.mult(x, x)
110+
squin.qubit.apply(id, q)
111+
return squin.qubit.measure(q)
112+
113+
main.print()
114+
115+
target = PyQrack(1)
116+
result = target.run(main)
117+
118+
assert result == [0]
119+
120+
104121
# TODO: remove
105-
test_qubit()
106-
test_x()
107-
test_basic_ops("x")
108-
test_cx()
122+
# test_qubit()
123+
# test_x()
124+
# test_basic_ops("x")
125+
# test_cx()
126+
# test_mult()

0 commit comments

Comments
 (0)