Skip to content

Commit eef3b91

Browse files
committed
Start implementing operator runtime
1 parent ca7f347 commit eef3b91

File tree

8 files changed

+258
-42
lines changed

8 files changed

+258
-42
lines changed

src/bloqade/pyqrack/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@
1414
# NOTE: The following import is for registering the method tables
1515
from .noise import native as native
1616
from .qasm2 import uop as uop, core as core, glob as glob, parallel as parallel
17-
from .squin import qubit as qubit
17+
from .squin import op as op, qubit as qubit
1818
from .target import PyQrack as PyQrack

src/bloqade/pyqrack/squin/__init__.py

Whitespace-only changes.

src/bloqade/pyqrack/squin/op.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from kirin import interp
2+
3+
from bloqade.squin import op
4+
5+
# from bloqade.pyqrack.reg import QubitState, PyQrackQubit
6+
from bloqade.pyqrack.base import PyQrackInterpreter
7+
8+
from .runtime import IdentityRuntime, OperatorRuntime, ProjectorRuntime
9+
10+
# from kirin.dialects import ilist
11+
12+
13+
@op.dialect.register(key="pyqrack")
14+
class PyQrackMethods(interp.MethodTable):
15+
16+
# @interp.impl(op.stmts.Kron)
17+
# def kron(
18+
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Kron
19+
# ):
20+
# is_unitary: bool = info.attribute(default=False)
21+
22+
# @interp.impl(op.stmts.Mult)
23+
# def mult(
24+
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Mult
25+
# ):
26+
# is_unitary: bool = info.attribute(default=False)
27+
28+
# @interp.impl(op.stmts.Adjoint)
29+
# def adjoint(
30+
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Adjoint
31+
# ):
32+
# is_unitary: bool = info.attribute(default=False)
33+
# op: ir.SSAValue = info.argument(OpType)
34+
# result: ir.ResultValue = info.result(OpType)
35+
36+
# @interp.impl(op.stmts.Scale)
37+
# def scale(
38+
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Scale
39+
# ):
40+
# is_unitary: bool = info.attribute(default=False)
41+
# op: ir.SSAValue = info.argument(OpType)
42+
# factor: ir.SSAValue = info.argument(Complex)
43+
# result: ir.ResultValue = info.result(OpType)
44+
45+
@interp.impl(op.stmts.Control)
46+
def control(
47+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Control
48+
):
49+
op = frame.get(stmt.op)
50+
n_controls = stmt.n_controls
51+
# FIXME: the method name here is dirty
52+
rt = OperatorRuntime(
53+
method_name="mc" + op.method_name,
54+
target_index=n_controls,
55+
ctrl_index=list(range(n_controls)),
56+
)
57+
return (rt,)
58+
59+
# @interp.impl(op.stmts.Rot)
60+
# def rot(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Rot):
61+
# axis: ir.SSAValue = info.argument(OpType)
62+
# angle: ir.SSAValue = info.argument(types.Float)
63+
# result: ir.ResultValue = info.result(OpType)
64+
65+
@interp.impl(op.stmts.Identity)
66+
def identity(
67+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Identity
68+
):
69+
return (IdentityRuntime(target_index=0, sites=stmt.sites),)
70+
71+
# @interp.impl(op.stmts.PhaseOp)
72+
# def phaseop(
73+
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.PhaseOp
74+
# ):
75+
# """
76+
# A phase operator.
77+
78+
# $$
79+
# PhaseOp(theta) = e^{i \theta} I
80+
# $$
81+
# """
82+
83+
# theta: ir.SSAValue = info.argument(types.Float)
84+
# result: ir.ResultValue = info.result(OpType)
85+
86+
# @interp.impl(op.stmts.ShiftOp)
87+
# def shiftop(
88+
# self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.ShiftOp
89+
# ):
90+
# """
91+
# A phase shift operator.
92+
93+
# $$
94+
# Shift(theta) = \\begin{bmatrix} 1 & 0 \\\\ 0 & e^{i \\theta} \\end{bmatrix}
95+
# $$
96+
# """
97+
98+
# theta: ir.SSAValue = info.argument(types.Float)
99+
# result: ir.ResultValue = info.result(OpType)
100+
101+
@interp.impl(op.stmts.X)
102+
@interp.impl(op.stmts.Y)
103+
@interp.impl(op.stmts.Z)
104+
@interp.impl(op.stmts.H)
105+
@interp.impl(op.stmts.S)
106+
@interp.impl(op.stmts.T)
107+
def operator(
108+
self,
109+
interp: PyQrackInterpreter,
110+
frame: interp.Frame,
111+
stmt: (
112+
op.stmts.X | op.stmts.Y | op.stmts.Z | op.stmts.H | op.stmts.S | op.stmts.T
113+
),
114+
):
115+
return (OperatorRuntime(method_name=stmt.name.lower(), target_index=0),)
116+
117+
@interp.impl(op.stmts.P0)
118+
@interp.impl(op.stmts.P1)
119+
def projector(
120+
self,
121+
interp: PyQrackInterpreter,
122+
frame: interp.Frame,
123+
stmt: op.stmts.P0 | op.stmts.P1,
124+
):
125+
state = isinstance(stmt, op.stmts.P1)
126+
return (ProjectorRuntime(to_state=state, target_index=0),)
127+
128+
@interp.impl(op.stmts.Sn)
129+
def sn(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sn):
130+
raise NotImplementedError()
131+
132+
@interp.impl(op.stmts.Sp)
133+
def sp(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sp):
134+
raise NotImplementedError()

src/bloqade/pyqrack/squin/qubit.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from bloqade.pyqrack.reg import QubitState, PyQrackQubit
88
from bloqade.pyqrack.base import PyQrackInterpreter
99

10+
from .runtime import OperatorRuntimeABC
11+
1012

1113
@qubit.dialect.register(key="pyqrack")
1214
class PyQrackMethods(interp.MethodTable):
@@ -23,10 +25,9 @@ def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New):
2325

2426
@interp.impl(qubit.Apply)
2527
def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Apply):
26-
# TODO
27-
# operator: ir.SSAValue = info.argument(OpType)
28-
# qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
29-
pass
28+
operator: OperatorRuntimeABC = frame.get(stmt.operator)
29+
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
30+
operator.apply(qubits=qubits)
3031

3132
@interp.impl(qubit.Measure)
3233
def measure(
@@ -55,20 +56,3 @@ def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Res
5556
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
5657
for qbit in qubits:
5758
qbit.sim_reg.force_m(qbit.addr, 0)
58-
59-
# @interp.impl(glob.UGate)
60-
# def ugate(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: glob.UGate):
61-
# registers: ilist.IList[ilist.IList[PyQrackQubit, Any], Any] = frame.get(
62-
# stmt.registers
63-
# )
64-
# theta, phi, lam = (
65-
# frame.get(stmt.theta),
66-
# frame.get(stmt.phi),
67-
# frame.get(stmt.lam),
68-
# )
69-
70-
# for qreg in registers:
71-
# for qarg in qreg:
72-
# if qarg.is_active():
73-
# interp.memory.sim_reg.u(qarg.addr, theta, phi, lam)
74-
# return ()
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from typing import Any, Optional
2+
from dataclasses import dataclass
3+
4+
from kirin.dialects import ilist
5+
6+
from bloqade.pyqrack import PyQrackQubit
7+
8+
9+
@dataclass
10+
class OperatorRuntimeABC:
11+
target_index: int
12+
13+
def apply(self, qubits: ilist.IList[PyQrackQubit, Any]) -> None:
14+
raise NotImplementedError(
15+
"Operator runtime base class should not be called directly, override the method"
16+
)
17+
18+
19+
@dataclass
20+
class OperatorRuntime(OperatorRuntimeABC):
21+
method_name: str
22+
ctrl_index: Optional[list[int]] = None
23+
24+
def apply(
25+
self,
26+
qubits: ilist.IList[PyQrackQubit, Any],
27+
):
28+
target_qubit = qubits[self.target_index]
29+
if self.ctrl_index is not None:
30+
ctrls = [qubits[i].addr for i in self.ctrl_index]
31+
getattr(target_qubit.sim_reg, self.method_name)(ctrls, target_qubit.addr)
32+
else:
33+
getattr(target_qubit.sim_reg, self.method_name)(target_qubit.addr)
34+
35+
36+
@dataclass
37+
class ProjectorRuntime(OperatorRuntimeABC):
38+
to_state: bool
39+
40+
def apply(
41+
self,
42+
qubits: ilist.IList[PyQrackQubit, Any],
43+
):
44+
target_qubit = qubits[self.target_index]
45+
target_qubit.sim_reg.force_m(target_qubit.addr, self.to_state)
46+
47+
48+
@dataclass
49+
class IdentityRuntime(OperatorRuntimeABC):
50+
# TODO: do we even need sites? The apply never does anything
51+
sites: int
52+
53+
def apply(self, qubits: ilist.IList[PyQrackQubit, Any]):
54+
pass

src/bloqade/squin/op/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88

99

1010
@_wraps(stmts.Kron)
11-
def kron(lhs: types.Op, rhs: types.Op, *, is_unitary: bool = False) -> types.Op: ...
11+
def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ...
1212

1313

1414
@_wraps(stmts.Adjoint)
15-
def adjoint(op: types.Op, *, is_unitary: bool = False) -> types.Op: ...
15+
def adjoint(op: types.Op) -> types.Op: ...
1616

1717

1818
@_wraps(stmts.Control)
19-
def control(op: types.Op, *, n_controls: int, is_unitary: bool = False) -> types.Op: ...
19+
def control(op: types.Op, *, n_controls: int) -> types.Op: ...
2020

2121

2222
@_wraps(stmts.Identity)

src/bloqade/squin/qubit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class MeasureAndReset(ir.Statement):
5050

5151
@statement(dialect=dialect)
5252
class Reset(ir.Statement):
53+
traits = frozenset({lowering.FromPythonCall()})
5354
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
5455

5556

test/pyqrack/test_squin.py

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import math
2+
3+
import pytest
14
from kirin.dialects import ilist
25

36
from bloqade import squin
@@ -43,23 +46,63 @@ def measure_and_reset():
4346
assert result == [0, 0, 0]
4447

4548

46-
# @squin.kernel
47-
# def main():
48-
# q = squin.qubit.new(3)
49-
# x = squin.op.x()
50-
# id = squin.op.identity(sites=2)
51-
52-
# # FIXME? Should we have a method apply(x, q, idx)?
53-
# squin.qubit.apply(squin.op.kron(x, id), q)
54-
55-
# return squin.qubit.measure(q)
56-
57-
58-
# main.print()
49+
def test_x():
50+
@squin.kernel
51+
def main():
52+
q = squin.qubit.new(1)
53+
x = squin.op.x()
54+
squin.qubit.apply(x, q)
55+
return squin.qubit.measure(q)
56+
57+
target = PyQrack(1)
58+
result = target.run(main)
59+
assert result == [1]
60+
61+
62+
@pytest.mark.parametrize(
63+
"op_name",
64+
[
65+
"x",
66+
"y",
67+
"z",
68+
"h",
69+
"s",
70+
"t",
71+
],
72+
)
73+
def test_basic_ops(op_name: str):
74+
@squin.kernel
75+
def main():
76+
q = squin.qubit.new(1)
77+
op = getattr(squin.op, op_name)()
78+
squin.qubit.apply(op, q)
79+
return q
80+
81+
target = PyQrack(1)
82+
result = target.run(main)
83+
assert isinstance(result, ilist.IList)
84+
assert isinstance(qubit := result[0], PyQrackQubit)
5985

60-
# target = PyQrack(2)
61-
# result = target.run(main)
86+
ket = qubit.sim_reg.out_ket()
87+
n = sum([abs(k) ** 2 for k in ket])
88+
assert math.isclose(n, 1, abs_tol=1e-6)
6289

6390

64-
if __name__ == "main":
65-
test_qubit()
91+
def test_cx():
92+
@squin.kernel
93+
def main():
94+
q = squin.qubit.new(2)
95+
x = squin.op.x()
96+
cx = squin.op.control(x, n_controls=1)
97+
squin.qubit.apply(cx, q)
98+
return q
99+
100+
target = PyQrack(2)
101+
target.run(main)
102+
103+
104+
# TODO: remove
105+
test_qubit()
106+
test_x()
107+
test_basic_ops("x")
108+
test_cx()

0 commit comments

Comments
 (0)