Skip to content

Commit 8f35f99

Browse files
david-plkaihsin
andauthored
Update remaining code base to new kirin version (#600)
CI is blocked by: * Requires kirin main * This kirin PR: QuEraComputing/kirin#563 * This kirin issue: QuEraComputing/kirin#564 Other than that, we should be good. I'm still targeting the kirin upgrade branch to make review easier, but this once the above issues are resolved, this can actually go into `main`. @weinbe58 please have a look at the address analysis, specifically at the `run_lattice` method. I had to change the signature a bit. --------- Co-authored-by: kaihsin <[email protected]>
1 parent 435b33b commit 8f35f99

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+188
-283
lines changed

src/bloqade/analysis/address/analysis.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class AddressAnalysis(Forward[Address]):
1515
This analysis pass can be used to track the global addresses of qubits and wires.
1616
"""
1717

18-
keys = ["qubit.address"]
18+
keys = ("qubit.address",)
1919
_const_prop: const.Propagate
2020
lattice = Address
2121
next_address: int = field(init=False)
@@ -45,7 +45,7 @@ def try_eval_const_prop(
4545
) -> interp.StatementResult[Address]:
4646
_frame = self._const_prop.initialize_frame(frame.code)
4747
_frame.set_values(stmt.args, tuple(x.result for x in args))
48-
result = self._const_prop.eval_stmt(_frame, stmt)
48+
result = self._const_prop.frame_eval(_frame, stmt)
4949

5050
match result:
5151
case interp.ReturnValue(constant_ret):
@@ -96,7 +96,8 @@ def run_lattice(
9696
self,
9797
callee: Address,
9898
inputs: tuple[Address, ...],
99-
kwargs: tuple[str, ...],
99+
keys: tuple[str, ...],
100+
kwargs: tuple[Address, ...],
100101
) -> Address:
101102
"""Run a callable lattice element with the given inputs and keyword arguments.
102103
@@ -111,15 +112,16 @@ def run_lattice(
111112
"""
112113

113114
match callee:
114-
case PartialLambda(code=code, argnames=argnames):
115-
_, ret = self.run_callable(
116-
code, (callee,) + self.permute_values(argnames, inputs, kwargs)
115+
case PartialLambda(code=code):
116+
_, ret = self.call(
117+
code, callee, *inputs, **{k: v for k, v in zip(keys, kwargs)}
117118
)
118-
return ret
119119
case ConstResult(const.Value(ir.Method() as method)):
120-
_, ret = self.run_method(
121-
method,
122-
self.permute_values(method.arg_names, inputs, kwargs),
120+
_, ret = self.call(
121+
method.code,
122+
self.method_self(method),
123+
*inputs,
124+
**{k: v for k, v in zip(keys, kwargs)},
123125
)
124126
return ret
125127
case _:
@@ -137,14 +139,12 @@ def get_const_value(self, addr: Address, typ: Type[T]) -> T | None:
137139

138140
return value
139141

140-
def eval_stmt_fallback(self, frame: ForwardFrame[Address], stmt: ir.Statement):
141-
args = frame.get_values(stmt.args)
142+
def eval_fallback(self, frame: ForwardFrame[Address], node: ir.Statement):
143+
args = frame.get_values(node.args)
142144
if types.is_tuple_of(args, ConstResult):
143-
return self.try_eval_const_prop(frame, stmt, args)
145+
return self.try_eval_const_prop(frame, node, args)
144146

145-
return tuple(Address.from_type(result.type) for result in stmt.results)
147+
return tuple(Address.from_type(result.type) for result in node.results)
146148

147-
def run_method(self, method: ir.Method, args: tuple[Address, ...]):
148-
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
149-
self_mt = ConstResult(const.Value(method))
150-
return self.run_callable(method.code, (self_mt,) + args)
149+
def method_self(self, method: ir.Method) -> Address:
150+
return ConstResult(const.Value(method))

src/bloqade/analysis/address/impls.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def map_(
9797

9898
results = []
9999
for ele in values:
100-
ret = interp_.run_lattice(fn, (ele,), ())
100+
ret = interp_.run_lattice(fn, (ele,), (), ())
101101
results.append(ret)
102102

103103
if isinstance(stmt, ilist.Map):
@@ -180,13 +180,10 @@ def invoke(
180180
frame: ForwardFrame[Address],
181181
stmt: func.Invoke,
182182
):
183-
184-
args = interp_.permute_values(
185-
stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
186-
)
187-
_, ret = interp_.run_method(
188-
stmt.callee,
189-
args,
183+
_, ret = interp_.call(
184+
stmt.callee.code,
185+
interp_.method_self(stmt.callee),
186+
*frame.get_values(stmt.inputs),
190187
)
191188

192189
return (ret,)
@@ -219,7 +216,8 @@ def call(
219216
result = interp_.run_lattice(
220217
frame.get(stmt.callee),
221218
frame.get_values(stmt.inputs),
222-
stmt.kwargs,
219+
stmt.keys,
220+
frame.get_values(stmt.kwargs),
223221
)
224222
return (result,)
225223

@@ -319,26 +317,28 @@ def ifelse(
319317
):
320318
body = stmt.then_body if const_cond.data else stmt.else_body
321319
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
322-
ret = interp_.run_ssacfg_region(body_frame, body, (address_cond,))
320+
ret = interp_.frame_call_region(body_frame, stmt, body, address_cond)
323321
# interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values())
324322
return ret
325323
else:
326324
# run both branches
327325
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
328-
then_results = interp_.run_ssacfg_region(
329-
then_frame, stmt.then_body, (address_cond,)
330-
)
331-
interp_.set_values(
332-
frame, then_frame.entries.keys(), then_frame.entries.values()
326+
then_results = interp_.frame_call_region(
327+
then_frame,
328+
stmt,
329+
stmt.then_body,
330+
address_cond,
333331
)
332+
frame.set_values(then_frame.entries.keys(), then_frame.entries.values())
334333

335334
with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
336-
else_results = interp_.run_ssacfg_region(
337-
else_frame, stmt.else_body, (address_cond,)
338-
)
339-
interp_.set_values(
340-
frame, else_frame.entries.keys(), else_frame.entries.values()
335+
else_results = interp_.frame_call_region(
336+
else_frame,
337+
stmt,
338+
stmt.else_body,
339+
address_cond,
341340
)
341+
frame.set_values(else_frame.entries.keys(), else_frame.entries.values())
342342
# TODO: pick the non-return value
343343
if isinstance(then_results, interp.ReturnValue) and isinstance(
344344
else_results, interp.ReturnValue
@@ -364,12 +364,12 @@ def for_loop(
364364
iter_type, iterable = interp_.unpack_iterable(frame.get(stmt.iterable))
365365

366366
if iter_type is None:
367-
return interp_.eval_stmt_fallback(frame, stmt)
367+
return interp_.eval_fallback(frame, stmt)
368368

369369
for value in iterable:
370370
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
371-
loop_vars = interp_.run_ssacfg_region(
372-
body_frame, stmt.body, (value,) + loop_vars
371+
loop_vars = interp_.frame_call_region(
372+
body_frame, stmt, stmt.body, value, *loop_vars
373373
)
374374

375375
if loop_vars is None:

src/bloqade/analysis/fidelity/analysis.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from kirin import ir
55
from kirin.lattice import EmptyLattice
66
from kirin.analysis import Forward
7-
from kirin.interp.value import Successor
87
from kirin.analysis.forward import ForwardFrame
98

109
from ..address import Address, AddressAnalysis
@@ -48,15 +47,11 @@ def main():
4847
The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
4948
"""
5049

51-
_current_gate_fidelity: float = field(init=False)
52-
5350
atom_survival_probability: list[float] = field(init=False)
5451
"""
5552
The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
5653
"""
5754

58-
_current_atom_survival_probability: list[float] = field(init=False)
59-
6055
addr_frame: ForwardFrame[Address] = field(init=False)
6156

6257
def initialize(self):
@@ -67,25 +62,15 @@ def initialize(self):
6762
]
6863
return self
6964

70-
def posthook_succ(self, frame: ForwardFrame, succ: Successor):
71-
self.gate_fidelity *= self._current_gate_fidelity
72-
for i, _current_survival in enumerate(self._current_atom_survival_probability):
73-
self.atom_survival_probability[i] *= _current_survival
74-
75-
def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
65+
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
7666
# NOTE: default is to conserve fidelity, so do nothing here
7767
return
7868

79-
def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):
80-
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
81-
82-
def run_analysis(
83-
self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
84-
) -> tuple[ForwardFrame, Any]:
85-
self._run_address_analysis(method, no_raise=no_raise)
86-
return super().run(method)
69+
def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]:
70+
self._run_address_analysis(method)
71+
return super().run(method, *args, **kwargs)
8772

88-
def _run_address_analysis(self, method: ir.Method, no_raise: bool):
73+
def _run_address_analysis(self, method: ir.Method):
8974
addr_analysis = AddressAnalysis(self.dialects)
9075
addr_frame, _ = addr_analysis.run(method=method)
9176
self.addr_frame = addr_frame

src/bloqade/analysis/measure_id/analysis.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,16 @@ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):
2222
measure_count = 0
2323

2424
def initialize_frame(
25-
self, code: ir.Statement, *, has_parent_access: bool = False
25+
self, node: ir.Statement, *, has_parent_access: bool = False
2626
) -> MeasureIDFrame:
27-
return MeasureIDFrame(code, has_parent_access=has_parent_access)
27+
return MeasureIDFrame(node, has_parent_access=has_parent_access)
2828

2929
# Still default to bottom,
3030
# but let constants return the softer "NoMeasureId" type from impl
31-
def eval_stmt_fallback(
32-
self, frame: ForwardFrame[MeasureId], stmt: ir.Statement
31+
def eval_fallback(
32+
self, frame: ForwardFrame[MeasureId], node: ir.Statement
3333
) -> tuple[MeasureId, ...]:
34-
return tuple(NotMeasureId() for _ in stmt.results)
35-
36-
def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
37-
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
38-
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
34+
return tuple(NotMeasureId() for _ in node.results)
3935

4036
# Xiu-zhe (Roger) Luo came up with this in the address analysis,
4137
# reused here for convenience (now modified to be a bit more graceful)
@@ -45,7 +41,7 @@ def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
4541
T = TypeVar("T")
4642

4743
def get_const_value(
48-
self, input_type: type[T], value: ir.SSAValue
44+
self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue
4945
) -> type[T] | None:
5046
if isinstance(hint := value.hints.get("const"), const.Value):
5147
data = hint.data

src/bloqade/analysis/measure_id/impls.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,10 @@ def return_(self, _: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Retu
138138
def invoke(
139139
self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke
140140
):
141-
_, ret = interp_.run_method(
142-
stmt.callee,
143-
interp_.permute_values(
144-
stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
145-
),
141+
_, ret = interp_.call(
142+
stmt.callee.code,
143+
interp_.method_self(stmt.callee),
144+
*frame.get_values(stmt.inputs),
146145
)
147146
return (ret,)
148147

src/bloqade/cirq_utils/emit/base.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -189,39 +189,8 @@ def initialize_frame(
189189
node, has_parent_access=has_parent_access, qubits=self.qubits
190190
)
191191

192-
def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
193-
return self.call(method, *args)
194-
195-
def run_callable_region(
196-
self,
197-
frame: EmitCirqFrame,
198-
code: ir.Statement,
199-
region: ir.Region,
200-
args: tuple,
201-
):
202-
if len(region.blocks) > 0:
203-
block_args = list(region.blocks[0].args)
204-
# NOTE: skip self arg
205-
frame.set_values(block_args[1:], args)
206-
207-
results = self.frame_eval(frame, code)
208-
if isinstance(results, tuple):
209-
if len(results) == 0:
210-
return self.void
211-
elif len(results) == 1:
212-
return results[0]
213-
raise interp.InterpreterError(f"Unexpected results {results}")
214-
215-
def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
216-
for stmt in block.stmts:
217-
result = self.frame_eval(frame, stmt)
218-
if isinstance(result, tuple):
219-
frame.set_values(stmt.results, result)
220-
221-
return self.circuit
222-
223192
def reset(self):
224-
pass
193+
self.circuit = cirq.Circuit()
225194

226195
def eval_fallback(self, frame: EmitCirqFrame, node: ir.Statement) -> tuple:
227196
return tuple(None for _ in range(len(node.results)))

src/bloqade/cirq_utils/lowering.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def main():
9999
```
100100
"""
101101

102-
target = Squin(dialects=dialects, circuit=circuit)
102+
target = Squin(dialects, circuit)
103103
body = target.run(
104104
circuit,
105105
source=str(circuit), # TODO: proper source string
@@ -144,8 +144,6 @@ def main():
144144
)
145145

146146
mt = ir.Method(
147-
mod=None,
148-
py_func=None,
149147
sym_name=kernel_name,
150148
arg_names=arg_names,
151149
dialects=dialects,

src/bloqade/gemini/analysis/logical_validation/analysis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ class GeminiLogicalValidationAnalysis(ValidationAnalysis):
99

1010
first_gate = True
1111

12-
def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
13-
if isinstance(stmt, squin.gate.stmts.Gate):
12+
def eval_fallback(self, frame: ValidationFrame, node: ir.Statement):
13+
if isinstance(node, squin.gate.stmts.Gate):
1414
# NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
1515
self.first_gate = False
1616

17-
return super().eval_stmt_fallback(frame, stmt)
17+
return super().eval_fallback(frame, node)

src/bloqade/native/upstream/squin2native.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,18 @@ def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method:
6262
all_dialects = chain.from_iterable(
6363
ker.dialects.data for kers in old_callgraph.defs.values() for ker in kers
6464
)
65-
new_dialects = (
66-
mt.dialects.union(all_dialects).discard(gate_dialect).union(kernel)
67-
)
65+
combined_dialects = mt.dialects.union(all_dialects).union(kernel)
6866

69-
out = mt.similar(new_dialects)
70-
UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out)
71-
CallGraphPass(new_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(out)
72-
# verify all kernels in the callgraph
67+
out = mt.similar(combined_dialects)
68+
UpdateDialectsOnCallGraph(combined_dialects, no_raise=no_raise)(out)
69+
CallGraphPass(combined_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(
70+
out
71+
)
72+
# verify all kernels in the callgraph and discard gate dialect
73+
out.dialects.discard(gate_dialect)
7374
new_callgraph = CallGraph(out)
7475
for ker in new_callgraph.edges.keys():
76+
ker.dialects.discard(gate_dialect)
7577
ker.verify()
7678

7779
return out

src/bloqade/pyqrack/target.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def run(
8787
"""
8888
fold = Fold(mt.dialects)
8989
fold(mt)
90-
return self._get_interp(mt).run(mt, args, kwargs)
90+
_, ret = self._get_interp(mt).run(mt, *args, **kwargs)
91+
return ret
9192

9293
def multi_run(
9394
self,

0 commit comments

Comments
 (0)