Skip to content

Commit 0246a0f

Browse files
committed
Change operator runtime implementation
1 parent 5116ada commit 0246a0f

File tree

3 files changed

+30
-28
lines changed

3 files changed

+30
-28
lines changed

src/bloqade/pyqrack/squin/op.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# from bloqade.pyqrack.reg import QubitState, PyQrackQubit
66
from bloqade.pyqrack.base import PyQrackInterpreter
77

8-
from .runtime import IdentityRuntime, OperatorRuntime, ProjectorRuntime
8+
from .runtime import ControlRuntime, IdentityRuntime, OperatorRuntime, ProjectorRuntime
99

1010
# from kirin.dialects import ilist
1111

@@ -49,10 +49,9 @@ def control(
4949
op = frame.get(stmt.op)
5050
n_controls = stmt.n_controls
5151
# FIXME: the method name here is dirty
52-
rt = OperatorRuntime(
52+
rt = ControlRuntime(
5353
method_name="mc" + op.method_name,
54-
target_index=n_controls,
55-
ctrl_index=list(range(n_controls)),
54+
n_controls=n_controls,
5655
)
5756
return (rt,)
5857

@@ -66,7 +65,7 @@ def control(
6665
def identity(
6766
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Identity
6867
):
69-
return (IdentityRuntime(target_index=0, sites=stmt.sites),)
68+
return (IdentityRuntime(sites=stmt.sites),)
7069

7170
# @interp.impl(op.stmts.PhaseOp)
7271
# def phaseop(
@@ -112,7 +111,7 @@ def operator(
112111
op.stmts.X | op.stmts.Y | op.stmts.Z | op.stmts.H | op.stmts.S | op.stmts.T
113112
),
114113
):
115-
return (OperatorRuntime(method_name=stmt.name.lower(), target_index=0),)
114+
return (OperatorRuntime(method_name=stmt.name.lower()),)
116115

117116
@interp.impl(op.stmts.P0)
118117
@interp.impl(op.stmts.P1)
@@ -123,7 +122,7 @@ def projector(
123122
stmt: op.stmts.P0 | op.stmts.P1,
124123
):
125124
state = isinstance(stmt, op.stmts.P1)
126-
return (ProjectorRuntime(to_state=state, target_index=0),)
125+
return (ProjectorRuntime(to_state=state),)
127126

128127
@interp.impl(op.stmts.Sn)
129128
def sn(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sn):

src/bloqade/pyqrack/squin/qubit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New):
2727
def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Apply):
2828
operator: OperatorRuntimeABC = frame.get(stmt.operator)
2929
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
30-
operator.apply(qubits=qubits)
30+
operator.apply(*qubits)
3131

3232
@interp.impl(qubit.Measure)
3333
def measure(
Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
1-
from typing import Any, Optional
21
from dataclasses import dataclass
32

4-
from kirin.dialects import ilist
5-
63
from bloqade.pyqrack import PyQrackQubit
74

85

96
@dataclass
107
class OperatorRuntimeABC:
11-
target_index: int
12-
13-
def apply(self, qubits: ilist.IList[PyQrackQubit, Any]) -> None:
8+
def apply(self, *qubits: PyQrackQubit) -> None:
149
raise NotImplementedError(
1510
"Operator runtime base class should not be called directly, override the method"
1611
)
@@ -19,18 +14,27 @@ def apply(self, qubits: ilist.IList[PyQrackQubit, Any]) -> None:
1914
@dataclass
2015
class OperatorRuntime(OperatorRuntimeABC):
2116
method_name: str
22-
ctrl_index: Optional[list[int]] = None
2317

2418
def apply(
2519
self,
26-
qubits: ilist.IList[PyQrackQubit, Any],
27-
):
28-
target_qubit = qubits[self.target_index]
29-
if self.ctrl_index is not None:
30-
ctrls = [qubits[i].addr for i in self.ctrl_index]
31-
getattr(target_qubit.sim_reg, self.method_name)(ctrls, target_qubit.addr)
32-
else:
33-
getattr(target_qubit.sim_reg, self.method_name)(target_qubit.addr)
20+
*qubits: PyQrackQubit,
21+
) -> None:
22+
getattr(qubits[-1].sim_reg, self.method_name)(qubits[-1].addr)
23+
24+
25+
@dataclass
26+
class ControlRuntime(OperatorRuntimeABC):
27+
method_name: str
28+
n_controls: int
29+
30+
def apply(
31+
self,
32+
*qubits: PyQrackQubit,
33+
) -> None:
34+
# NOTE: this is a bit odd, since you can "skip" qubits by making n_controls < len(qubits)
35+
ctrls = [qbit.addr for qbit in qubits[: self.n_controls]]
36+
target = qubits[-1]
37+
getattr(target.sim_reg, self.method_name)(ctrls, target.addr)
3438

3539

3640
@dataclass
@@ -39,16 +43,15 @@ class ProjectorRuntime(OperatorRuntimeABC):
3943

4044
def apply(
4145
self,
42-
qubits: ilist.IList[PyQrackQubit, Any],
43-
):
44-
target_qubit = qubits[self.target_index]
45-
target_qubit.sim_reg.force_m(target_qubit.addr, self.to_state)
46+
*qubits: PyQrackQubit,
47+
) -> None:
48+
qubits[-1].sim_reg.force_m(qubits[-1].addr, self.to_state)
4649

4750

4851
@dataclass
4952
class IdentityRuntime(OperatorRuntimeABC):
5053
# TODO: do we even need sites? The apply never does anything
5154
sites: int
5255

53-
def apply(self, qubits: ilist.IList[PyQrackQubit, Any]):
56+
def apply(self, *qubits: PyQrackQubit) -> None:
5457
pass

0 commit comments

Comments
 (0)