Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b718282
Update analysis to new kirin API
david-pl Oct 31, 2025
2fed75d
Fix fidelity analysis
david-pl Oct 31, 2025
41feea8
Fix address analysis
david-pl Oct 31, 2025
9c57ac1
Fix PyQrack interpreter
david-pl Oct 31, 2025
8985457
Fix PyQrack return value
david-pl Oct 31, 2025
e999074
Discard gate dialect AFTER rewrite to native is done
david-pl Oct 31, 2025
d0ff2d7
Fix PyQrack target
david-pl Oct 31, 2025
6e81083
Fix self arg in qasm2 loading
david-pl Oct 31, 2025
a4dd592
All PyQrack tests work
david-pl Oct 31, 2025
0aed338
Fix qasm2 parallel & global rewrites
david-pl Nov 3, 2025
cfd1fed
Fix types in qbraid lowering
david-pl Nov 3, 2025
3b2fc6e
Unmark & fix a bunch of qasm2 tests
david-pl Nov 3, 2025
681cb8b
Fix get_qubit_id test
david-pl Nov 3, 2025
5b6df40
Cleaner self arg in qasm loading
david-pl Nov 3, 2025
16bd272
Fix stim tests by removing newlines
david-pl Nov 3, 2025
8b92c5c
Fix run_lattice in address analysis
david-pl Nov 5, 2025
141ff59
Remove old code
david-pl Nov 5, 2025
fc61d02
Clean up cirq emit
david-pl Nov 5, 2025
6b20ec0
Revert change in emit constant
david-pl Nov 5, 2025
17e8ffb
Clean up cirq lowering a bit
david-pl Nov 5, 2025
9ff1d99
Merge branch 'david/571-kirin-upgrade-branch' into david/574-upgrade-…
david-pl Nov 5, 2025
636083b
Update kernel validation
david-pl Nov 5, 2025
d86cad3
Remove unneeded method from QASM2 emit
david-pl Nov 5, 2025
ef14987
Remove some commented code
david-pl Nov 5, 2025
8191481
fix dialect group for codegen test
kaihsin Nov 5, 2025
e00e129
Remove unneeded run_no_raise method
david-pl Nov 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/bloqade/analysis/address/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AddressAnalysis(Forward[Address]):
This analysis pass can be used to track the global addresses of qubits and wires.
"""

keys = ["qubit.address"]
keys = ("qubit.address",)
_const_prop: const.Propagate
lattice = Address
next_address: int = field(init=False)
Expand Down Expand Up @@ -45,7 +45,7 @@ def try_eval_const_prop(
) -> interp.StatementResult[Address]:
_frame = self._const_prop.initialize_frame(frame.code)
_frame.set_values(stmt.args, tuple(x.result for x in args))
result = self._const_prop.eval_stmt(_frame, stmt)
result = self._const_prop.frame_eval(_frame, stmt)

match result:
case interp.ReturnValue(constant_ret):
Expand Down Expand Up @@ -96,7 +96,8 @@ def run_lattice(
self,
callee: Address,
inputs: tuple[Address, ...],
kwargs: tuple[str, ...],
keys: tuple[str, ...],
kwargs: tuple[Address, ...],
) -> Address:
"""Run a callable lattice element with the given inputs and keyword arguments.

Expand All @@ -111,15 +112,16 @@ def run_lattice(
"""

match callee:
case PartialLambda(code=code, argnames=argnames):
_, ret = self.run_callable(
code, (callee,) + self.permute_values(argnames, inputs, kwargs)
case PartialLambda(code=code):
_, ret = self.call(
code, callee, *inputs, **{k: v for k, v in zip(keys, kwargs)}
)
return ret
case ConstResult(const.Value(ir.Method() as method)):
_, ret = self.run_method(
method,
self.permute_values(method.arg_names, inputs, kwargs),
_, ret = self.call(
method.code,
self.method_self(method),
*inputs,
**{k: v for k, v in zip(keys, kwargs)},
)
return ret
case _:
Expand All @@ -137,14 +139,12 @@ def get_const_value(self, addr: Address, typ: Type[T]) -> T | None:

return value

def eval_stmt_fallback(self, frame: ForwardFrame[Address], stmt: ir.Statement):
args = frame.get_values(stmt.args)
def eval_fallback(self, frame: ForwardFrame[Address], node: ir.Statement):
args = frame.get_values(node.args)
if types.is_tuple_of(args, ConstResult):
return self.try_eval_const_prop(frame, stmt, args)
return self.try_eval_const_prop(frame, node, args)

return tuple(Address.from_type(result.type) for result in stmt.results)
return tuple(Address.from_type(result.type) for result in node.results)

def run_method(self, method: ir.Method, args: tuple[Address, ...]):
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
self_mt = ConstResult(const.Value(method))
return self.run_callable(method.code, (self_mt,) + args)
def method_self(self, method: ir.Method) -> Address:
return ConstResult(const.Value(method))
46 changes: 23 additions & 23 deletions src/bloqade/analysis/address/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def map_(

results = []
for ele in values:
ret = interp_.run_lattice(fn, (ele,), ())
ret = interp_.run_lattice(fn, (ele,), (), ())
results.append(ret)

if isinstance(stmt, ilist.Map):
Expand Down Expand Up @@ -180,13 +180,10 @@ def invoke(
frame: ForwardFrame[Address],
stmt: func.Invoke,
):

args = interp_.permute_values(
stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
)
_, ret = interp_.run_method(
stmt.callee,
args,
_, ret = interp_.call(
stmt.callee.code,
interp_.method_self(stmt.callee),
*frame.get_values(stmt.inputs),
)

return (ret,)
Expand Down Expand Up @@ -219,7 +216,8 @@ def call(
result = interp_.run_lattice(
frame.get(stmt.callee),
frame.get_values(stmt.inputs),
stmt.kwargs,
stmt.keys,
frame.get_values(stmt.kwargs),
)
return (result,)

Expand Down Expand Up @@ -319,26 +317,28 @@ def ifelse(
):
body = stmt.then_body if const_cond.data else stmt.else_body
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
ret = interp_.run_ssacfg_region(body_frame, body, (address_cond,))
ret = interp_.frame_call_region(body_frame, stmt, body, address_cond)
# interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values())
return ret
else:
# run both branches
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
then_results = interp_.run_ssacfg_region(
then_frame, stmt.then_body, (address_cond,)
)
interp_.set_values(
frame, then_frame.entries.keys(), then_frame.entries.values()
then_results = interp_.frame_call_region(
then_frame,
stmt,
stmt.then_body,
address_cond,
)
frame.set_values(then_frame.entries.keys(), then_frame.entries.values())

with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
else_results = interp_.run_ssacfg_region(
else_frame, stmt.else_body, (address_cond,)
)
interp_.set_values(
frame, else_frame.entries.keys(), else_frame.entries.values()
else_results = interp_.frame_call_region(
else_frame,
stmt,
stmt.else_body,
address_cond,
)
frame.set_values(else_frame.entries.keys(), else_frame.entries.values())
# TODO: pick the non-return value
if isinstance(then_results, interp.ReturnValue) and isinstance(
else_results, interp.ReturnValue
Expand All @@ -364,12 +364,12 @@ def for_loop(
iter_type, iterable = interp_.unpack_iterable(frame.get(stmt.iterable))

if iter_type is None:
return interp_.eval_stmt_fallback(frame, stmt)
return interp_.eval_fallback(frame, stmt)

for value in iterable:
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
loop_vars = interp_.run_ssacfg_region(
body_frame, stmt.body, (value,) + loop_vars
loop_vars = interp_.frame_call_region(
body_frame, stmt, stmt.body, value, *loop_vars
)

if loop_vars is None:
Expand Down
25 changes: 5 additions & 20 deletions src/bloqade/analysis/fidelity/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from kirin import ir
from kirin.lattice import EmptyLattice
from kirin.analysis import Forward
from kirin.interp.value import Successor
from kirin.analysis.forward import ForwardFrame

from ..address import Address, AddressAnalysis
Expand Down Expand Up @@ -48,15 +47,11 @@ def main():
The fidelity of the gate set described by the analysed program. It reduces whenever a noise channel is encountered.
"""

_current_gate_fidelity: float = field(init=False)

atom_survival_probability: list[float] = field(init=False)
"""
The probabilities that each of the atoms in the register survive the duration of the analysed program. The order of the list follows the order they are in the register.
"""

_current_atom_survival_probability: list[float] = field(init=False)

addr_frame: ForwardFrame[Address] = field(init=False)

def initialize(self):
Expand All @@ -67,25 +62,15 @@ def initialize(self):
]
return self

def posthook_succ(self, frame: ForwardFrame, succ: Successor):
self.gate_fidelity *= self._current_gate_fidelity
for i, _current_survival in enumerate(self._current_atom_survival_probability):
self.atom_survival_probability[i] *= _current_survival

def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
# NOTE: default is to conserve fidelity, so do nothing here
return

def run_method(self, method: ir.Method, args: tuple[EmptyLattice, ...]):
return self.run_callable(method.code, (self.lattice.bottom(),) + args)

def run_analysis(
self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
) -> tuple[ForwardFrame, Any]:
self._run_address_analysis(method, no_raise=no_raise)
return super().run(method)
def run(self, method: ir.Method, *args, **kwargs) -> tuple[ForwardFrame, Any]:
self._run_address_analysis(method)
return super().run(method, *args, **kwargs)

def _run_address_analysis(self, method: ir.Method, no_raise: bool):
def _run_address_analysis(self, method: ir.Method):
addr_analysis = AddressAnalysis(self.dialects)
addr_frame, _ = addr_analysis.run(method=method)
self.addr_frame = addr_frame
Expand Down
16 changes: 6 additions & 10 deletions src/bloqade/analysis/measure_id/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,16 @@ class MeasurementIDAnalysis(ForwardExtra[MeasureIDFrame, MeasureId]):
measure_count = 0

def initialize_frame(
self, code: ir.Statement, *, has_parent_access: bool = False
self, node: ir.Statement, *, has_parent_access: bool = False
) -> MeasureIDFrame:
return MeasureIDFrame(code, has_parent_access=has_parent_access)
return MeasureIDFrame(node, has_parent_access=has_parent_access)

# Still default to bottom,
# but let constants return the softer "NoMeasureId" type from impl
def eval_stmt_fallback(
self, frame: ForwardFrame[MeasureId], stmt: ir.Statement
def eval_fallback(
self, frame: ForwardFrame[MeasureId], node: ir.Statement
) -> tuple[MeasureId, ...]:
return tuple(NotMeasureId() for _ in stmt.results)

def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
return tuple(NotMeasureId() for _ in node.results)

# Xiu-zhe (Roger) Luo came up with this in the address analysis,
# reused here for convenience (now modified to be a bit more graceful)
Expand All @@ -45,7 +41,7 @@ def run_method(self, method: ir.Method, args: tuple[MeasureId, ...]):
T = TypeVar("T")

def get_const_value(
self, input_type: type[T], value: ir.SSAValue
self, input_type: type[T] | tuple[type[T], ...], value: ir.SSAValue
) -> type[T] | None:
if isinstance(hint := value.hints.get("const"), const.Value):
data = hint.data
Expand Down
9 changes: 4 additions & 5 deletions src/bloqade/analysis/measure_id/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,10 @@ def return_(self, _: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Retu
def invoke(
self, interp_: MeasurementIDAnalysis, frame: interp.Frame, stmt: func.Invoke
):
_, ret = interp_.run_method(
stmt.callee,
interp_.permute_values(
stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
),
_, ret = interp_.call(
stmt.callee.code,
interp_.method_self(stmt.callee),
*frame.get_values(stmt.inputs),
)
return (ret,)

Expand Down
33 changes: 1 addition & 32 deletions src/bloqade/cirq_utils/emit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,39 +189,8 @@ def initialize_frame(
node, has_parent_access=has_parent_access, qubits=self.qubits
)

def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
return self.call(method, *args)

def run_callable_region(
self,
frame: EmitCirqFrame,
code: ir.Statement,
region: ir.Region,
args: tuple,
):
if len(region.blocks) > 0:
block_args = list(region.blocks[0].args)
# NOTE: skip self arg
frame.set_values(block_args[1:], args)

results = self.frame_eval(frame, code)
if isinstance(results, tuple):
if len(results) == 0:
return self.void
elif len(results) == 1:
return results[0]
raise interp.InterpreterError(f"Unexpected results {results}")

def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
for stmt in block.stmts:
result = self.frame_eval(frame, stmt)
if isinstance(result, tuple):
frame.set_values(stmt.results, result)

return self.circuit

def reset(self):
pass
self.circuit = cirq.Circuit()

def eval_fallback(self, frame: EmitCirqFrame, node: ir.Statement) -> tuple:
return tuple(None for _ in range(len(node.results)))
Expand Down
4 changes: 1 addition & 3 deletions src/bloqade/cirq_utils/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def main():
```
"""

target = Squin(dialects=dialects, circuit=circuit)
target = Squin(dialects, circuit)
body = target.run(
circuit,
source=str(circuit), # TODO: proper source string
Expand Down Expand Up @@ -144,8 +144,6 @@ def main():
)

mt = ir.Method(
mod=None,
py_func=None,
sym_name=kernel_name,
arg_names=arg_names,
dialects=dialects,
Expand Down
6 changes: 3 additions & 3 deletions src/bloqade/gemini/analysis/logical_validation/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ class GeminiLogicalValidationAnalysis(ValidationAnalysis):

first_gate = True

def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
if isinstance(stmt, squin.gate.stmts.Gate):
def eval_fallback(self, frame: ValidationFrame, node: ir.Statement):
if isinstance(node, squin.gate.stmts.Gate):
# NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
self.first_gate = False

return super().eval_stmt_fallback(frame, stmt)
return super().eval_fallback(frame, node)
16 changes: 9 additions & 7 deletions src/bloqade/native/upstream/squin2native.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,18 @@ def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method:
all_dialects = chain.from_iterable(
ker.dialects.data for kers in old_callgraph.defs.values() for ker in kers
)
new_dialects = (
mt.dialects.union(all_dialects).discard(gate_dialect).union(kernel)
)
combined_dialects = mt.dialects.union(all_dialects).union(kernel)

out = mt.similar(new_dialects)
UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out)
CallGraphPass(new_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(out)
# verify all kernels in the callgraph
out = mt.similar(combined_dialects)
UpdateDialectsOnCallGraph(combined_dialects, no_raise=no_raise)(out)
CallGraphPass(combined_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(
out
)
# verify all kernels in the callgraph and discard gate dialect
out.dialects.discard(gate_dialect)
new_callgraph = CallGraph(out)
for ker in new_callgraph.edges.keys():
ker.dialects.discard(gate_dialect)
ker.verify()

return out
3 changes: 2 additions & 1 deletion src/bloqade/pyqrack/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def run(
"""
fold = Fold(mt.dialects)
fold(mt)
return self._get_interp(mt).run(mt, args, kwargs)
_, ret = self._get_interp(mt).run(mt, *args, **kwargs)
return ret

def multi_run(
self,
Expand Down
12 changes: 5 additions & 7 deletions src/bloqade/pyqrack/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,12 @@ class PyQrackSimulatorTask(AbstractSimulatorTask[Param, RetType, MemoryType]):
pyqrack_interp: PyQrackInterpreter[MemoryType]

def run(self) -> RetType:
return cast(
RetType,
self.pyqrack_interp.run(
self.kernel,
args=self.args,
kwargs=self.kwargs,
),
_, ret = self.pyqrack_interp.run(
self.kernel,
*self.args,
**self.kwargs,
)
return cast(RetType, ret)

@property
def state(self) -> MemoryType:
Expand Down
Loading
Loading