Skip to content

Commit 7aa7e60

Browse files
david-plweinbe58zhenrongliewkaihsin
authored
Upgrade to kirin v0.21 (#572)
Co-authored-by: Phillip Weinberg <[email protected]> Co-authored-by: Dennis Liew <[email protected]> Co-authored-by: kaihsin <[email protected]> Co-authored-by: kaihsin <[email protected]>
1 parent b450a22 commit 7aa7e60

File tree

111 files changed

+954
-663
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

111 files changed

+954
-663
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ requires-python = ">=3.10"
1313
dependencies = [
1414
"numpy>=1.22.0",
1515
"scipy>=1.13.1",
16-
"kirin-toolchain~=0.17.30",
16+
"kirin-toolchain~=0.21.0",
1717
"rich>=13.9.4",
1818
"pydantic>=1.3.0,<2.11.0",
1919
"pandas>=2.2.3",

src/bloqade/analysis/address/analysis.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class AddressAnalysis(Forward[Address]):
1515
This analysis pass can be used to track the global addresses of qubits and wires.
1616
"""
1717

18-
keys = ["qubit.address"]
18+
keys = ("qubit.address",)
1919
_const_prop: const.Propagate
2020
lattice = Address
2121
next_address: int = field(init=False)
@@ -45,7 +45,7 @@ def try_eval_const_prop(
4545
) -> interp.StatementResult[Address]:
4646
_frame = self._const_prop.initialize_frame(frame.code)
4747
_frame.set_values(stmt.args, tuple(x.result for x in args))
48-
result = self._const_prop.eval_stmt(_frame, stmt)
48+
result = self._const_prop.frame_eval(_frame, stmt)
4949

5050
match result:
5151
case interp.ReturnValue(constant_ret):
@@ -96,7 +96,8 @@ def run_lattice(
9696
self,
9797
callee: Address,
9898
inputs: tuple[Address, ...],
99-
kwargs: tuple[str, ...],
99+
keys: tuple[str, ...],
100+
kwargs: tuple[Address, ...],
100101
) -> Address:
101102
"""Run a callable lattice element with the given inputs and keyword arguments.
102103
@@ -111,15 +112,16 @@ def run_lattice(
111112
"""
112113

113114
match callee:
114-
case PartialLambda(code=code, argnames=argnames):
115-
_, ret = self.run_callable(
116-
code, (callee,) + self.permute_values(argnames, inputs, kwargs)
115+
case PartialLambda(code=code):
116+
_, ret = self.call(
117+
code, callee, *inputs, **{k: v for k, v in zip(keys, kwargs)}
117118
)
118-
return ret
119119
case ConstResult(const.Value(ir.Method() as method)):
120-
_, ret = self.run_method(
121-
method,
122-
self.permute_values(method.arg_names, inputs, kwargs),
120+
_, ret = self.call(
121+
method.code,
122+
self.method_self(method),
123+
*inputs,
124+
**{k: v for k, v in zip(keys, kwargs)},
123125
)
124126
return ret
125127
case _:
@@ -137,14 +139,12 @@ def get_const_value(self, addr: Address, typ: Type[T]) -> T | None:
137139

138140
return value
139141

140-
def eval_stmt_fallback(self, frame: ForwardFrame[Address], stmt: ir.Statement):
141-
args = frame.get_values(stmt.args)
142+
def eval_fallback(self, frame: ForwardFrame[Address], node: ir.Statement):
143+
args = frame.get_values(node.args)
142144
if types.is_tuple_of(args, ConstResult):
143-
return self.try_eval_const_prop(frame, stmt, args)
145+
return self.try_eval_const_prop(frame, node, args)
144146

145-
return tuple(Address.from_type(result.type) for result in stmt.results)
147+
return tuple(Address.from_type(result.type) for result in node.results)
146148

147-
def run_method(self, method: ir.Method, args: tuple[Address, ...]):
148-
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
149-
self_mt = ConstResult(const.Value(method))
150-
return self.run_callable(method.code, (self_mt,) + args)
149+
def method_self(self, method: ir.Method) -> Address:
150+
return ConstResult(const.Value(method))

src/bloqade/analysis/address/impls.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def map_(
9797

9898
results = []
9999
for ele in values:
100-
ret = interp_.run_lattice(fn, (ele,), ())
100+
ret = interp_.run_lattice(fn, (ele,), (), ())
101101
results.append(ret)
102102

103103
if isinstance(stmt, ilist.Map):
@@ -180,13 +180,10 @@ def invoke(
180180
frame: ForwardFrame[Address],
181181
stmt: func.Invoke,
182182
):
183-
184-
args = interp_.permute_values(
185-
stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
186-
)
187-
_, ret = interp_.run_method(
188-
stmt.callee,
189-
args,
183+
_, ret = interp_.call(
184+
stmt.callee.code,
185+
interp_.method_self(stmt.callee),
186+
*frame.get_values(stmt.inputs),
190187
)
191188

192189
return (ret,)
@@ -219,7 +216,8 @@ def call(
219216
result = interp_.run_lattice(
220217
frame.get(stmt.callee),
221218
frame.get_values(stmt.inputs),
222-
stmt.kwargs,
219+
stmt.keys,
220+
frame.get_values(stmt.kwargs),
223221
)
224222
return (result,)
225223

@@ -319,26 +317,28 @@ def ifelse(
319317
):
320318
body = stmt.then_body if const_cond.data else stmt.else_body
321319
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
322-
ret = interp_.run_ssacfg_region(body_frame, body, (address_cond,))
320+
ret = interp_.frame_call_region(body_frame, stmt, body, address_cond)
323321
# interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values())
324322
return ret
325323
else:
326324
# run both branches
327325
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
328-
then_results = interp_.run_ssacfg_region(
329-
then_frame, stmt.then_body, (address_cond,)
330-
)
331-
interp_.set_values(
332-
frame, then_frame.entries.keys(), then_frame.entries.values()
326+
then_results = interp_.frame_call_region(
327+
then_frame,
328+
stmt,
329+
stmt.then_body,
330+
address_cond,
333331
)
332+
frame.set_values(then_frame.entries.keys(), then_frame.entries.values())
334333

335334
with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
336-
else_results = interp_.run_ssacfg_region(
337-
else_frame, stmt.else_body, (address_cond,)
338-
)
339-
interp_.set_values(
340-
frame, else_frame.entries.keys(), else_frame.entries.values()
335+
else_results = interp_.frame_call_region(
336+
else_frame,
337+
stmt,
338+
stmt.else_body,
339+
address_cond,
341340
)
341+
frame.set_values(else_frame.entries.keys(), else_frame.entries.values())
342342
# TODO: pick the non-return value
343343
if isinstance(then_results, interp.ReturnValue) and isinstance(
344344
else_results, interp.ReturnValue
@@ -364,12 +364,12 @@ def for_loop(
364364
iter_type, iterable = interp_.unpack_iterable(frame.get(stmt.iterable))
365365

366366
if iter_type is None:
367-
return interp_.eval_stmt_fallback(frame, stmt)
367+
return interp_.eval_fallback(frame, stmt)
368368

369369
for value in iterable:
370370
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
371-
loop_vars = interp_.run_ssacfg_region(
372-
body_frame, stmt.body, (value,) + loop_vars
371+
loop_vars = interp_.frame_call_region(
372+
body_frame, stmt, stmt.body, value, *loop_vars
373373
)
374374

375375
if loop_vars is None:

src/bloqade/analysis/fidelity/analysis.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from kirin import ir
55
from kirin.lattice import EmptyLattice
66
from kirin.analysis import Forward
7-
from kirin.interp.value import Successor
87
from kirin.analysis.forward import ForwardFrame
98

109
from ..address import Address, AddressAnalysis
@@ -48,15 +47,11 @@ def main():
4847
The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
4948
"""
5049

51-
_current_gate_fidelity: float = field(init=False)
52-
5350
atom_survival_probability: list[float] = field(init=False)
5451
"""
5552
The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
5653
"""
5754

58-
_current_atom_survival_probability: list[float] = field(init=False)
59-
6055
addr_frame: ForwardFrame[Address] = field(init=False)
6156

6257
def initialize(self):
@@ -67,28 +62,21 @@ def initialize(self):
6762
]
6863
return self
6964

70-
def posthook_succ(self, frame: ForwardFrame, succ: Successor):
71-
self.gate_fidelity *= self._current_gate_fidelity
72-
for i, _current_survival in enumerate(self._current_atom_survival_probability):
73-
self.atom_survival_probability[i] *= _current_survival
74-
75-
def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
65+
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
7666
# NOTE: default is to conserve fidelity, so do nothing here
7767
return
7868

79-
def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):
80-
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
69+
def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]:
70+
self._run_address_analysis(method)
71+
return super().run(method, *args, **kwargs)
8172

82-
def run_analysis(
83-
self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
84-
) -> tuple[ForwardFrame, Any]:
85-
self._run_address_analysis(method, no_raise=no_raise)
86-
return super().run_analysis(method, args, no_raise=no_raise)
87-
88-
def _run_address_analysis(self, method: ir.Method, no_raise: bool):
73+
def _run_address_analysis(self, method: ir.Method):
8974
addr_analysis = AddressAnalysis(self.dialects)
90-
addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise)
75+
addr_frame, _ = addr_analysis.run(method=method)
9176
self.addr_frame = addr_frame
9277

9378
# NOTE: make sure we have as many probabilities as we have addresses
9479
self.atom_survival_probability = [1.0] * addr_analysis.qubit_count
80+
81+
def method_self(self, method: ir.Method) -> EmptyLattice:
82+
return self.lattice.bottom()

src/bloqade/analysis/measure_id/analysis.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,16 @@ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):
2222
measure_count = 0
2323

2424
def initialize_frame(
25-
self, code: ir.Statement, *, has_parent_access: bool = False
25+
self, node: ir.Statement, *, has_parent_access: bool = False
2626
) -> MeasureIDFrame:
27-
return MeasureIDFrame(code, has_parent_access=has_parent_access)
27+
return MeasureIDFrame(node, has_parent_access=has_parent_access)
2828

2929
# Still default to bottom,
3030
# but let constants return the softer "NoMeasureId" type from impl
31-
def eval_stmt_fallback(
32-
self, frame: ForwardFrame[MeasureId], stmt: ir.Statement
31+
def eval_fallback(
32+
self, frame: ForwardFrame[MeasureId], node: ir.Statement
3333
) -> tuple[MeasureId, ...]:
34-
return tuple(NotMeasureId() for _ in stmt.results)
35-
36-
def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
37-
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
38-
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
34+
return tuple(NotMeasureId() for _ in node.results)
3935

4036
# Xiu-zhe (Roger) Luo came up with this in the address analysis,
4137
# reused here for convenience (now modified to be a bit more graceful)
@@ -45,11 +41,14 @@ def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
4541
T = TypeVar("T")
4642

4743
def get_const_value(
48-
self, input_type: type[T], value: ir.SSAValue
44+
self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue
4945
) -> type[T] | None:
5046
if isinstance(hint := value.hints.get("const"), const.Value):
5147
data = hint.data
5248
if isinstance(data, input_type):
5349
return hint.data
5450

5551
return None
52+
53+
def method_self(self, method: ir.Method) -> MeasureId:
54+
return self.lattice.bottom()

src/bloqade/analysis/measure_id/impls.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,10 @@ def return_(self, _: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Retu
152152
def invoke(
153153
self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke
154154
):
155-
_, ret = interp_.run_method(
156-
stmt.callee,
157-
interp_.permute_values(
158-
stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
159-
),
155+
_, ret = interp_.call(
156+
stmt.callee.code,
157+
interp_.method_self(stmt.callee),
158+
*frame.get_values(stmt.inputs),
160159
)
161160
return (ret,)
162161

0 commit comments

Comments
 (0)