Skip to content

Commit 914b1fd

Browse files
committed
Implement kron runtime
1 parent 7d466a0 commit 914b1fd

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

src/bloqade/pyqrack/squin/op.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from bloqade.pyqrack.base import PyQrackInterpreter
77

88
from .runtime import (
9+
KronRuntime,
910
MultRuntime,
1011
ControlRuntime,
1112
IdentityRuntime,
@@ -19,11 +20,13 @@
1920
@op.dialect.register(key="pyqrack")
2021
class PyQrackMethods(interp.MethodTable):
2122

22-
# @interp.impl(op.stmts.Kron)
23-
# def kron(
24-
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Kron
25-
# ):
26-
# is_unitary: bool = info.attribute(default=False)
23+
@interp.impl(op.stmts.Kron)
24+
def kron(
25+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Kron
26+
):
27+
lhs = frame.get(stmt.lhs)
28+
rhs = frame.get(stmt.rhs)
29+
return (KronRuntime(lhs, rhs),)
2730

2831
@interp.impl(op.stmts.Mult)
2932
def mult(

src/bloqade/pyqrack/squin/runtime.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,15 @@ class MultRuntime(OperatorRuntimeABC):
5656
def apply(self, *qubits: PyQrackQubit) -> None:
5757
self.rhs.apply(*qubits)
5858
self.lhs.apply(*qubits)
59+
60+
61+
@dataclass
62+
class KronRuntime(OperatorRuntimeABC):
63+
lhs: OperatorRuntimeABC
64+
rhs: OperatorRuntimeABC
65+
66+
def apply(self, *qubits: PyQrackQubit) -> None:
67+
assert len(qubits) == 2
68+
qbit1, qbit2 = qubits
69+
self.lhs.apply(qbit1)
70+
self.rhs.apply(qbit2)

test/pyqrack/test_squin.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,25 @@ def main():
118118
assert result == [0]
119119

120120

121+
def test_kron():
122+
@squin.kernel
123+
def main():
124+
q = squin.qubit.new(2)
125+
x = squin.op.x()
126+
k = squin.op.kron(x, x)
127+
squin.qubit.apply(k, q)
128+
return squin.qubit.measure(q)
129+
130+
target = PyQrack(2)
131+
result = target.run(main)
132+
133+
assert result == [1, 1]
134+
135+
121136
# TODO: remove
122137
# test_qubit()
123138
# test_x()
124139
# test_basic_ops("x")
125140
# test_cx()
126141
# test_mult()
142+
# test_kron()

0 commit comments

Comments
 (0)