Skip to content

Commit b997c4a

Browse files
committed
Draft implementation for wire dialect
1 parent 92456df commit b997c4a

File tree

5 files changed

+102
-3
lines changed

5 files changed

+102
-3
lines changed

src/bloqade/pyqrack/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
CRegister as CRegister,
44
QubitState as QubitState,
55
Measurement as Measurement,
6+
PyQrackWire as PyQrackWire,
67
PyQrackQubit as PyQrackQubit,
78
)
89
from .base import (

src/bloqade/pyqrack/reg.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,8 @@ def is_active(self) -> bool:
7070
def drop(self):
7171
"""Drop the qubit in-place."""
7272
self.state = QubitState.Lost
73+
74+
75+
@dataclass
76+
class PyQrackWire:
77+
qubit: PyQrackQubit

src/bloqade/pyqrack/squin/wire.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from kirin import interp
2+
3+
from bloqade.squin import wire
4+
from bloqade.pyqrack.reg import PyQrackWire, PyQrackQubit
5+
from bloqade.pyqrack.base import PyQrackInterpreter
6+
7+
from .runtime import OperatorRuntimeABC
8+
9+
10+
@wire.dialect.register(key="pyqrack")
11+
class PyQrackMethods(interp.MethodTable):
12+
# @interp.impl(wire.Wrap)
13+
# def wrap(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: wire.Wrap):
14+
# traits = frozenset({lowering.FromPythonCall(), WireTerminator()})
15+
# wire: ir.SSAValue = info.argument(WireType)
16+
# qubit: ir.SSAValue = info.argument(QubitType)
17+
18+
@interp.impl(wire.Unwrap)
19+
def unwrap(
20+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: wire.Unwrap
21+
):
22+
q: PyQrackQubit = frame.get(stmt.qubit)
23+
return (PyQrackWire(q),)
24+
25+
@interp.impl(wire.Apply)
26+
def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: wire.Apply):
27+
ws = stmt.inputs
28+
assert isinstance(ws, tuple)
29+
qubits: list[PyQrackQubit] = []
30+
for w in ws:
31+
assert isinstance(w, PyQrackWire)
32+
qubits.append(w.qubit)
33+
op: OperatorRuntimeABC = frame.get(stmt.operator)
34+
35+
op.apply(*qubits)
36+
37+
out_ws = [PyQrackWire(qbit) for qbit in qubits]
38+
return (out_ws,)
39+
40+
@interp.impl(wire.Measure)
41+
def measure(
42+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: wire.Measure
43+
):
44+
w: PyQrackWire = frame.get(stmt.wire)
45+
qbit = w.qubit
46+
res: int = qbit.sim_reg.m(qbit.addr)
47+
return (res,)
48+
49+
@interp.impl(wire.MeasureAndReset)
50+
def measure_and_reset(
51+
self,
52+
interp: PyQrackInterpreter,
53+
frame: interp.Frame,
54+
stmt: wire.MeasureAndReset,
55+
):
56+
w: PyQrackWire = frame.get(stmt.wire)
57+
qbit = w.qubit
58+
res: int = qbit.sim_reg.m(qbit.addr)
59+
qbit.sim_reg.force_m(qbit.addr, False)
60+
new_w = PyQrackWire(qbit)
61+
return (new_w, res)
62+
63+
@interp.impl(wire.Reset)
64+
def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: wire.Reset):
65+
w: PyQrackWire = frame.get(stmt.wire)
66+
qbit = w.qubit
67+
qbit.sim_reg.force_m(qbit.addr, False)
68+
new_w = PyQrackWire(qbit)
69+
return (new_w,)

src/bloqade/squin/wire.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
from kirin import ir, types, interp, lowering
1010
from kirin.decl import info, statement
11+
from kirin.lowering import wraps
1112

12-
from bloqade.types import QubitType
13+
from bloqade.types import Qubit, QubitType
1314

14-
from .op.types import OpType
15+
from .op.types import Op, OpType
1516

1617
# from kirin.lowering import wraps
1718

@@ -101,3 +102,11 @@ class ConstPropWire(interp.MethodTable):
101102
def apply(self, interp, frame, stmt: Apply):
102103

103104
return frame.get_values(stmt.inputs)
105+
106+
107+
@wraps(Unwrap)
108+
def unwrap(qubit: Qubit) -> Wire: ...
109+
110+
111+
@wraps(Apply)
112+
def apply(op: Op, w: Wire) -> Wire: ...

test/pyqrack/test_squin.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from kirin.dialects import ilist
55

66
from bloqade import squin
7-
from bloqade.pyqrack import PyQrack, PyQrackQubit
7+
from bloqade.pyqrack import PyQrack, PyQrackWire, PyQrackQubit
88

99

1010
def test_qubit():
@@ -311,6 +311,20 @@ def main():
311311
assert result == [1, 1, 1]
312312

313313

314+
def test_wire():
315+
@squin.wired
316+
def main():
317+
w = squin.wire.unwrap(1)
318+
x = squin.op.x()
319+
squin.wire.apply(x, w)
320+
return w
321+
322+
target = PyQrack(1)
323+
result = target.run(main)
324+
assert isinstance(result, PyQrackWire)
325+
assert result.qubit.sim_reg.out_ket() == [0, 1]
326+
327+
314328
# TODO: remove
315329
# test_qubit()
316330
# test_x()
@@ -329,3 +343,4 @@ def main():
329343
# test_broadcast()
330344
# test_u3()
331345
# test_clifford_str()
346+
# test_wire()

0 commit comments

Comments
 (0)