Skip to content

Commit 2709b00

Browse files
kaihsinjohnzl-777
andcommitted
Fix rewrite from squin -> stim with measurement (#370)
This PR 1. unify liftIfs related rewrite rule btwn qasm2 and stim 2. refactor squin to stim rewrite. 3. Mark squin noise op as pure. 4. Add ilist canonicalize passes (with two rules added) Co-authored-by: @johnzl-777 --------- Co-authored-by: John Long <[email protected]>
1 parent 7bdcb0b commit 2709b00

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1465
-158
lines changed

src/bloqade/analysis/address/impls.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66
from kirin.analysis import ForwardFrame, const
77
from kirin.dialects import cf, py, scf, func, ilist
88

9-
from bloqade import squin
10-
119
from .lattice import (
1210
Address,
1311
NotQubit,
1412
AddressReg,
15-
AddressWire,
1613
AddressQubit,
1714
AddressTuple,
1815
)
@@ -163,63 +160,3 @@ def for_loop(
163160
return # if terminate is Return, there is no result
164161

165162
return loop_vars
166-
167-
168-
# Address lattice elements we can work with:
169-
## NotQubit (bottom), AnyAddress (top)
170-
171-
## AddressTuple -> data: tuple[Address, ...]
172-
### Recursive type, could contain itself or other variants
173-
### This pops up in cases where you can have an IList/Tuple
174-
### That contains elements that could be other Address types
175-
176-
## AddressReg -> data: Sequence[int]
177-
### specific to creation of a register of qubits
178-
179-
## AddressQubit -> data: int
180-
### Base qubit address type
181-
182-
183-
@squin.wire.dialect.register(key="qubit.address")
184-
class SquinWireMethodTable(interp.MethodTable):
185-
186-
@interp.impl(squin.wire.Unwrap)
187-
def unwrap(
188-
self,
189-
interp_: AddressAnalysis,
190-
frame: ForwardFrame[Address],
191-
stmt: squin.wire.Unwrap,
192-
):
193-
194-
origin_qubit = frame.get(stmt.qubit)
195-
196-
if isinstance(origin_qubit, AddressQubit):
197-
return (AddressWire(origin_qubit=origin_qubit),)
198-
else:
199-
return (Address.top(),)
200-
201-
@interp.impl(squin.wire.Apply)
202-
def apply(
203-
self,
204-
interp_: AddressAnalysis,
205-
frame: ForwardFrame[Address],
206-
stmt: squin.wire.Apply,
207-
):
208-
return frame.get_values(stmt.inputs)
209-
210-
211-
@squin.qubit.dialect.register(key="qubit.address")
212-
class SquinQubitMethodTable(interp.MethodTable):
213-
214-
# This can be treated like a QRegNew impl
215-
@interp.impl(squin.qubit.New)
216-
def new(
217-
self,
218-
interp_: AddressAnalysis,
219-
frame: ForwardFrame[Address],
220-
stmt: squin.qubit.New,
221-
):
222-
n_qubits = interp_.get_const_value(int, stmt.n_qubits)
223-
addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits))
224-
interp_.next_address += n_qubits
225-
return (addr,)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from . import impls as impls
2+
from .analysis import MeasurementIDAnalysis as MeasurementIDAnalysis
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from typing import TypeVar
2+
3+
from kirin import ir, interp
4+
from kirin.analysis import Forward, const
5+
from kirin.analysis.forward import ForwardFrame
6+
7+
from .lattice import MeasureId, NotMeasureId
8+
9+
10+
class MeasurementIDAnalysis(Forward[MeasureId]):
11+
12+
keys = ["measure_id"]
13+
lattice = MeasureId
14+
# for every kind of measurement encountered, increment this
15+
# then use this to generate the negative values for target rec indices
16+
measure_count = 0
17+
18+
# Still default to bottom,
19+
# but let constants return the softer "NoMeasureId" type from impl
20+
def eval_stmt_fallback(
21+
self, frame: ForwardFrame[MeasureId], stmt: ir.Statement
22+
) -> tuple[MeasureId, ...]:
23+
return tuple(NotMeasureId() for _ in stmt.results)
24+
25+
def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
26+
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
27+
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
28+
29+
T = TypeVar("T")
30+
31+
# Xiu-zhe (Roger) Luo came up with this in the address analysis,
32+
# reused here for convenience
33+
# TODO: Remove this function once upgrade to kirin 0.18 happens,
34+
# method is built-in to interpreter then
35+
def get_const_value(self, input_type: type[T], value: ir.SSAValue) -> T:
36+
if isinstance(hint := value.hints.get("const"), const.Value):
37+
data = hint.data
38+
if isinstance(data, input_type):
39+
return hint.data
40+
raise interp.InterpreterError(
41+
f"Expected constant value <type = {input_type}>, got {data}"
42+
)
43+
raise interp.InterpreterError(
44+
f"Expected constant value <type = {input_type}>, got {value}"
45+
)
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
from kirin import types as kirin_types, interp
2+
from kirin.dialects import py, scf, func, ilist
3+
4+
from bloqade.squin import wire, qubit
5+
6+
from .lattice import (
7+
AnyMeasureId,
8+
NotMeasureId,
9+
MeasureIdBool,
10+
MeasureIdTuple,
11+
InvalidMeasureId,
12+
)
13+
from .analysis import MeasurementIDAnalysis
14+
15+
## Can't do wire right now because of
16+
## unresolved RFC on return type
17+
# from bloqade.squin import wire
18+
19+
20+
@qubit.dialect.register(key="measure_id")
21+
class SquinQubit(interp.MethodTable):
22+
23+
@interp.impl(qubit.MeasureQubit)
24+
def measure_qubit(
25+
self,
26+
interp: MeasurementIDAnalysis,
27+
frame: interp.Frame,
28+
stmt: qubit.MeasureQubit,
29+
):
30+
interp.measure_count += 1
31+
return (MeasureIdBool(interp.measure_count),)
32+
33+
@interp.impl(qubit.MeasureQubitList)
34+
def measure_qubit_list(
35+
self,
36+
interp: MeasurementIDAnalysis,
37+
frame: interp.Frame,
38+
stmt: qubit.MeasureQubitList,
39+
):
40+
41+
# try to get the length of the list
42+
## "...safely assume the type inference will give you what you need"
43+
qubits_type = stmt.qubits.type
44+
# vars[0] is just the type of the elements in the ilist,
45+
# vars[1] can contain a literal with length information
46+
num_qubits = qubits_type.vars[1]
47+
if not isinstance(num_qubits, kirin_types.Literal):
48+
return (AnyMeasureId(),)
49+
50+
measure_id_bools = []
51+
for _ in range(num_qubits.data):
52+
interp.measure_count += 1
53+
measure_id_bools.append(MeasureIdBool(interp.measure_count))
54+
55+
return (MeasureIdTuple(data=tuple(measure_id_bools)),)
56+
57+
58+
@wire.dialect.register(key="measure_id")
59+
class SquinWire(interp.MethodTable):
60+
61+
@interp.impl(wire.Measure)
62+
def measure_qubit(
63+
self,
64+
interp: MeasurementIDAnalysis,
65+
frame: interp.Frame,
66+
stmt: wire.Measure,
67+
):
68+
interp.measure_count += 1
69+
return (MeasureIdBool(interp.measure_count),)
70+
71+
72+
@ilist.dialect.register(key="measure_id")
73+
class IList(interp.MethodTable):
74+
@interp.impl(ilist.New)
75+
# Because of the way GetItem works,
76+
# A user could create an ilist of bools that
77+
# ends up being a mixture of MeasureIdBool and NotMeasureId
78+
def new_ilist(
79+
self,
80+
interp: MeasurementIDAnalysis,
81+
frame: interp.Frame,
82+
stmt: ilist.New,
83+
):
84+
85+
measure_ids_in_ilist = frame.get_values(stmt.values)
86+
return (MeasureIdTuple(data=tuple(measure_ids_in_ilist)),)
87+
88+
89+
@py.tuple.dialect.register(key="measure_id")
90+
class PyTuple(interp.MethodTable):
91+
@interp.impl(py.tuple.New)
92+
def new_tuple(
93+
self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.tuple.New
94+
):
95+
measure_ids_in_tuple = frame.get_values(stmt.args)
96+
return (MeasureIdTuple(data=tuple(measure_ids_in_tuple)),)
97+
98+
99+
@py.indexing.dialect.register(key="measure_id")
100+
class PyIndexing(interp.MethodTable):
101+
@interp.impl(py.GetItem)
102+
def getitem(
103+
self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.GetItem
104+
):
105+
idx = interp.get_const_value(int, stmt.index)
106+
obj = frame.get(stmt.obj)
107+
if isinstance(obj, MeasureIdTuple):
108+
return (obj.data[idx],)
109+
# just propagate these down the line
110+
elif isinstance(obj, (AnyMeasureId, NotMeasureId)):
111+
return (obj,)
112+
else:
113+
return (InvalidMeasureId(),)
114+
115+
116+
@py.binop.dialect.register(key="measure_id")
117+
class PyBinOp(interp.MethodTable):
118+
@interp.impl(py.Add)
119+
def add(self, interp: MeasurementIDAnalysis, frame: interp.Frame, stmt: py.Add):
120+
lhs = frame.get(stmt.lhs)
121+
rhs = frame.get(stmt.rhs)
122+
123+
if isinstance(lhs, MeasureIdTuple) and isinstance(rhs, MeasureIdTuple):
124+
return (MeasureIdTuple(data=lhs.data + rhs.data),)
125+
else:
126+
return (InvalidMeasureId(),)
127+
128+
129+
@func.dialect.register(key="measure_id")
130+
class Func(interp.MethodTable):
131+
@interp.impl(func.Return)
132+
def return_(self, _: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Return):
133+
return interp.ReturnValue(frame.get(stmt.value))
134+
135+
# taken from Address Analysis implementation from Xiu-zhe (Roger) Luo
136+
@interp.impl(
137+
func.Invoke
138+
) # we know the callee already, func.Call would mean we don't know the callee @ compile time
139+
def invoke(
140+
self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke
141+
):
142+
_, ret = interp_.run_method(
143+
stmt.callee,
144+
interp_.permute_values(
145+
stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
146+
),
147+
)
148+
return (ret,)
149+
150+
151+
# Just let analysis propagate through
152+
# scf, particularly IfElse
153+
@scf.dialect.register(key="measure_id")
154+
class Scf(scf.absint.Methods):
155+
pass
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import final
2+
from dataclasses import dataclass
3+
4+
from kirin.lattice import (
5+
SingletonMeta,
6+
BoundedLattice,
7+
SimpleJoinMixin,
8+
SimpleMeetMixin,
9+
)
10+
11+
# Taken directly from Kai-Hsin Wu's implementation
12+
# with minor changes to names and addition of CanMeasureId type
13+
14+
15+
@dataclass
16+
class MeasureId(
17+
SimpleJoinMixin["MeasureId"],
18+
SimpleMeetMixin["MeasureId"],
19+
BoundedLattice["MeasureId"],
20+
):
21+
22+
@classmethod
23+
def bottom(cls) -> "MeasureId":
24+
return InvalidMeasureId()
25+
26+
@classmethod
27+
def top(cls) -> "MeasureId":
28+
return AnyMeasureId()
29+
30+
31+
# Can pop up if user constructs some list containing a mixture
32+
# of bools from measure results and other places,
33+
# in which case the whole list is invalid
34+
@final
35+
@dataclass
36+
class InvalidMeasureId(MeasureId, metaclass=SingletonMeta):
37+
38+
def is_subseteq(self, other: MeasureId) -> bool:
39+
return True
40+
41+
42+
@final
43+
@dataclass
44+
class AnyMeasureId(MeasureId, metaclass=SingletonMeta):
45+
46+
def is_subseteq(self, other: MeasureId) -> bool:
47+
return isinstance(other, AnyMeasureId)
48+
49+
50+
@final
51+
@dataclass
52+
class NotMeasureId(MeasureId, metaclass=SingletonMeta):
53+
54+
def is_subseteq(self, other: MeasureId) -> bool:
55+
return isinstance(other, NotMeasureId)
56+
57+
58+
@final
59+
@dataclass
60+
class MeasureIdBool(MeasureId):
61+
idx: int
62+
63+
def is_subseteq(self, other: MeasureId) -> bool:
64+
if isinstance(other, MeasureIdBool):
65+
return self.idx == other.idx
66+
return False
67+
68+
69+
# Might be nice to have some print override
70+
# here so all the CanMeasureId's/other types are consolidated for
71+
# readability
72+
73+
74+
@final
75+
@dataclass
76+
class MeasureIdTuple(MeasureId):
77+
data: tuple[MeasureId, ...]
78+
79+
def is_subseteq(self, other: MeasureId) -> bool:
80+
if isinstance(other, MeasureIdTuple):
81+
return all(a.is_subseteq(b) for a, b in zip(self.data, other.data))
82+
return False

src/bloqade/qasm2/passes/unroll_if.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,22 @@
77
ConstantFold,
88
CommonSubexpressionElimination,
99
)
10+
from kirin.dialects import scf, func
1011

11-
from ..rewrite.split_ifs import LiftThenBody, SplitIfStmts
12+
from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts
13+
14+
from ..dialects.uop.stmts import SingleQubitGate, TwoQubitCtrlGate
15+
from ..dialects.core.stmts import Reset, Measure
16+
17+
AllowedThenType = (SingleQubitGate, TwoQubitCtrlGate, Measure, Reset)
18+
DontLiftType = AllowedThenType + (scf.Yield, func.Return, func.Invoke)
1219

1320

1421
class UnrollIfs(Pass):
1522
"""This pass lifts statements that are not UOP out of the if body and then splits whatever is left into multiple if statements so you obtain valid QASM2"""
1623

1724
def unsafe_run(self, mt: ir.Method):
18-
result = Walk(LiftThenBody()).rewrite(mt.code)
25+
result = Walk(LiftThenBody(exclude_stmts=DontLiftType)).rewrite(mt.code)
1926
result = Walk(SplitIfStmts()).rewrite(mt.code).join(result)
2027
result = (
2128
Fixpoint(Walk(Chain(ConstantFold(), CommonSubexpressionElimination())))

src/bloqade/rewrite/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList

0 commit comments

Comments
 (0)