diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index a9ae40e8..ee5d8414 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -192,7 +192,10 @@ def unwrap( origin_qubit = frame.get(stmt.qubit) - return (AddressWire(origin_qubit=origin_qubit),) + if isinstance(origin_qubit, AddressQubit): + return (AddressWire(origin_qubit=origin_qubit),) + else: + return (Address.top(),) @interp.impl(squin.wire.Apply) def apply( @@ -201,14 +204,7 @@ def apply( frame: ForwardFrame[Address], stmt: squin.wire.Apply, ): - - origin_qubits = tuple( - [frame.get(input_elem).origin_qubit for input_elem in stmt.inputs] - ) - new_address_wires = tuple( - [AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits] - ) - return new_address_wires + return frame.get_values(stmt.inputs) @squin.qubit.dialect.register(key="qubit.address") diff --git a/src/bloqade/noise/native/model.py b/src/bloqade/noise/native/model.py index 1804f00b..90597939 100644 --- a/src/bloqade/noise/native/model.py +++ b/src/bloqade/noise/native/model.py @@ -102,10 +102,9 @@ class MoveNoiseModelABC(abc.ABC): params: MoveNoiseParams = field(default_factory=MoveNoiseParams) """Parameters for calculating move noise.""" - @classmethod @abc.abstractmethod def parallel_cz_errors( - cls, ctrls: List[int], qargs: List[int], rest: List[int] + self, ctrls: List[int], qargs: List[int], rest: List[int] ) -> Dict[Tuple[float, float, float, float], List[int]]: """Takes a set of ctrls and qargs and returns a noise model for all qubits.""" pass diff --git a/src/bloqade/qasm2/parse/lowering.py b/src/bloqade/qasm2/parse/lowering.py index 69496134..614732ea 100644 --- a/src/bloqade/qasm2/parse/lowering.py +++ b/src/bloqade/qasm2/parse/lowering.py @@ -200,8 +200,9 @@ def lower_literal(self, state: lowering.State[ast.Node], value) -> ir.SSAValue: elif isinstance(value, float): stmt = expr.ConstFloat(value=value) else: - raise lowering.BuildError(f"Unsupported literal type {type(value)}") - + raise lowering.BuildError( + f"Expected value of type float or int, got {type(value)}." + ) state.current_frame.push(stmt) return stmt.result @@ -216,6 +217,8 @@ def visit_MainProgram(self, state: lowering.State[ast.Node], node: ast.MainProgr dialects = ["qasm2.core", "qasm2.uop", "qasm2.expr"] elif isinstance(node.header, ast.Kirin): dialects = node.header.dialects + else: + raise lowering.BuildError(f"Unexpected node header {node.header}") for dialect in dialects: if dialect not in allowed: @@ -412,6 +415,8 @@ def visit_Bit(self, state: lowering.State[ast.Node], node: ast.Bit): stmt = core.QRegGet(reg, addr.result) elif reg.type.is_subseteq(CRegType): stmt = core.CRegGet(reg, addr.result) + else: + raise lowering.BuildError(f"Unexpected register type {reg.type}") return state.current_frame.push(stmt).result def visit_Call(self, state: lowering.State[ast.Node], node: ast.Call): diff --git a/src/bloqade/qasm2/passes/fold.py b/src/bloqade/qasm2/passes/fold.py index afb1b880..fdc47326 100644 --- a/src/bloqade/qasm2/passes/fold.py +++ b/src/bloqade/qasm2/passes/fold.py @@ -19,7 +19,7 @@ from kirin.analysis import const from kirin.dialects import scf, ilist from kirin.ir.method import Method -from kirin.rewrite.abc import RewriteResult +from kirin.rewrite.result import RewriteResult from bloqade.qasm2.dialects import expr diff --git a/src/bloqade/qasm2/rewrite/heuristic_noise.py b/src/bloqade/qasm2/rewrite/heuristic_noise.py index 6180ccef..27167cad 100644 --- a/src/bloqade/qasm2/rewrite/heuristic_noise.py +++ b/src/bloqade/qasm2/rewrite/heuristic_noise.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, cast from dataclasses import field, dataclass from kirin import ir @@ -226,8 +226,12 @@ def rewrite_parallel_cz_gate(self, node: parallel.CZ): and isinstance(qargs, address.AddressTuple) and all(isinstance(addr, address.AddressQubit) for addr in qargs.data) ): - ctrl_qubits = list(map(lambda addr: addr.data, ctrls.data)) - qarg_qubits = list(map(lambda addr: addr.data, qargs.data)) + ctrl_qubits = list( + map(lambda addr: cast(address.AddressQubit, addr).data, ctrls.data) + ) + qarg_qubits = list( + map(lambda addr: cast(address.AddressQubit, addr).data, qargs.data) + ) rest = sorted( set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits) ) diff --git a/src/bloqade/qasm2/rewrite/native_gates.py b/src/bloqade/qasm2/rewrite/native_gates.py index 2fb9ee6a..6f7bf851 100644 --- a/src/bloqade/qasm2/rewrite/native_gates.py +++ b/src/bloqade/qasm2/rewrite/native_gates.py @@ -70,7 +70,7 @@ def _circuit_diagram_info_(self, args): return "*", "CU" -def around(val): +def around(val) -> float: return float(np.around(val, 14)) @@ -78,7 +78,7 @@ def one_qubit_gate_to_u3_angles(op: cirq.Operation) -> tuple[float, float, float lam, theta, phi = ( # Z angle, Y angle, then Z angle cirq.deconstruct_single_qubit_matrix_into_angles(cirq.unitary(op)) ) - return tuple(map(around, (theta, phi, lam))) + return around(theta), around(phi), around(lam) @dataclass diff --git a/src/bloqade/qasm2/rewrite/register.py b/src/bloqade/qasm2/rewrite/register.py index 3784d6f0..bddca2c8 100644 --- a/src/bloqade/qasm2/rewrite/register.py +++ b/src/bloqade/qasm2/rewrite/register.py @@ -1,6 +1,7 @@ from kirin import ir from kirin.dialects import py -from kirin.rewrite.abc import RewriteRule, RewriteResult +from kirin.rewrite.abc import RewriteRule +from kirin.rewrite.result import RewriteResult from bloqade.qasm2.dialects import core diff --git a/src/bloqade/qasm2/rewrite/uop_to_parallel.py b/src/bloqade/qasm2/rewrite/uop_to_parallel.py index aa7d5bb7..6a69c970 100644 --- a/src/bloqade/qasm2/rewrite/uop_to_parallel.py +++ b/src/bloqade/qasm2/rewrite/uop_to_parallel.py @@ -3,9 +3,10 @@ from dataclasses import field, dataclass from kirin import ir -from kirin.rewrite import abc as rewrite_abc from kirin.dialects import py, ilist +from kirin.rewrite.abc import RewriteRule from kirin.analysis.const import lattice +from kirin.rewrite.result import RewriteResult from bloqade.analysis import address from bloqade.qasm2.dialects import uop, core, parallel @@ -14,7 +15,7 @@ class MergePolicyABC(abc.ABC): @abc.abstractmethod - def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult: + def __call__(self, node: ir.Statement) -> RewriteResult: pass @classmethod @@ -141,10 +142,10 @@ def from_analysis( group_numbers=group_numbers, ) - def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult: + def __call__(self, node: ir.Statement) -> RewriteResult: if node not in self.group_numbers: - return rewrite_abc.RewriteResult() + return RewriteResult() group_number = self.group_numbers[node] group = self.merge_groups[group_number] @@ -157,9 +158,7 @@ def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult: if self.group_has_merged[group_number]: node.delete() - return rewrite_abc.RewriteResult( - has_done_something=self.group_has_merged[group_number] - ) + return RewriteResult(has_done_something=self.group_has_merged[group_number]) def move_and_collect_qubit_list( self, qargs: List[ir.SSAValue], node: ir.Statement @@ -219,14 +218,14 @@ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]): ctrls.append(stmt.ctrls) qargs.append(stmt.qargs) else: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) ctrls_values = self.move_and_collect_qubit_list(ctrls, node) qargs_values = self.move_and_collect_qubit_list(qargs, node) if ctrls_values is None or qargs_values is None: # give up if we cannot determine the address or cannot move the qubits - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) new_ctrls = ilist.New(values=ctrls_values) new_qargs = ilist.New(values=qargs_values) @@ -238,7 +237,7 @@ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]): node.delete() - return rewrite_abc.RewriteResult(has_done_something=True) + return RewriteResult(has_done_something=True) def rewrite_group_U(self, node: ir.Statement, group: List[ir.Statement]): return self.rewrite_group_u(node, group) @@ -252,13 +251,13 @@ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]): elif isinstance(stmt, parallel.UGate): qargs.append(stmt.qargs) else: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) assert isinstance(node, (uop.UGate, parallel.UGate)) qargs_values = self.move_and_collect_qubit_list(qargs, node) if qargs_values is None: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) new_qargs = ilist.New(values=qargs_values) new_gate = parallel.UGate( @@ -271,7 +270,7 @@ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]): new_gate.insert_before(node) node.delete() - return rewrite_abc.RewriteResult(has_done_something=True) + return RewriteResult(has_done_something=True) def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]): qargs = [] @@ -282,14 +281,14 @@ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]): elif isinstance(stmt, parallel.RZ): qargs.append(stmt.qargs) else: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) assert isinstance(node, (uop.RZ, parallel.RZ)) qargs_values = self.move_and_collect_qubit_list(qargs, node) if qargs_values is None: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) new_qargs = ilist.New(values=qargs_values) new_gate = parallel.RZ( @@ -300,7 +299,7 @@ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]): new_gate.insert_before(node) node.delete() - return rewrite_abc.RewriteResult(has_done_something=True) + return RewriteResult(has_done_something=True) def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]): qargs = [] @@ -310,13 +309,13 @@ def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]): qargs_values = self.move_and_collect_qubit_list(qargs, node) if qargs_values is None: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) new_node = uop.Barrier(qargs=qargs_values) new_node.insert_before(node) node.delete() - return rewrite_abc.RewriteResult(has_done_something=True) + return RewriteResult(has_done_something=True) class GreedyMixin(MergePolicyABC): @@ -385,11 +384,11 @@ class SimpleOptimalMergePolicy(OptimalMixIn, SimpleMergePolicy): @dataclass -class UOpToParallelRule(rewrite_abc.RewriteRule): +class UOpToParallelRule(RewriteRule): merge_rewriters: Dict[ir.Block | None, MergePolicyABC] - def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult: + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: merge_rewriter = self.merge_rewriters.get( - node.parent_block, lambda _: rewrite_abc.RewriteResult() + node.parent_block, lambda _: RewriteResult() ) return merge_rewriter(node) diff --git a/src/bloqade/qbraid/schema.py b/src/bloqade/qbraid/schema.py index ae103cd9..54eed1c2 100644 --- a/src/bloqade/qbraid/schema.py +++ b/src/bloqade/qbraid/schema.py @@ -9,7 +9,7 @@ class Operation(BaseModel, frozen=True, extra="forbid"): op_type: str = Field(init=False) -class CZ(Operation): +class CZ(Operation, frozen=True): """A CZ gate operation. Fields: @@ -22,7 +22,7 @@ class CZ(Operation): participants: Tuple[Union[Tuple[int], Tuple[int, int]], ...] -class GlobalRz(Operation): +class GlobalRz(Operation, frozen=True): """GlobalRz operation. Fields: @@ -34,7 +34,7 @@ class GlobalRz(Operation): phi: float -class GlobalW(Operation): +class GlobalW(Operation, frozen=True): """GlobalW operation. Fields: @@ -48,7 +48,7 @@ class GlobalW(Operation): phi: float -class LocalRz(Operation): +class LocalRz(Operation, frozen=True): """LocalRz operation. Fields: @@ -63,7 +63,7 @@ class LocalRz(Operation): phi: float -class LocalW(Operation): +class LocalW(Operation, frozen=True): """LocalW operation. Fields: @@ -80,7 +80,7 @@ class LocalW(Operation): phi: float -class Measurement(Operation): +class Measurement(Operation, frozen=True): """Measurement operation. Fields: @@ -95,9 +95,7 @@ class Measurement(Operation): participants: Tuple[int, ...] -OperationType = TypeVar( - "OperationType", bound=Union[CZ, GlobalRz, GlobalW, LocalRz, LocalW, Measurement] -) +OperationType = CZ | GlobalRz | GlobalW | LocalRz | LocalW | Measurement class ErrorModel(BaseModel, frozen=True, extra="forbid"): @@ -106,7 +104,7 @@ class ErrorModel(BaseModel, frozen=True, extra="forbid"): error_model_type: str = Field(init=False) -class PauliErrorModel(ErrorModel): +class PauliErrorModel(ErrorModel, frozen=True): """Pauli error model. Fields: @@ -131,7 +129,7 @@ class ErrorOperation(BaseModel, Generic[ErrorModelType], frozen=True, extra="for survival_prob: Tuple[float, ...] -class CZError(ErrorOperation[ErrorModelType]): +class CZError(ErrorOperation[ErrorModelType], frozen=True): """CZError operation. Fields: @@ -149,7 +147,7 @@ class CZError(ErrorOperation[ErrorModelType]): single_error: ErrorModelType -class SingleQubitError(ErrorOperation[ErrorModelType]): +class SingleQubitError(ErrorOperation[ErrorModelType], frozen=True): """SingleQubitError operation. Fields: diff --git a/src/bloqade/squin/analysis/nsites/analysis.py b/src/bloqade/squin/analysis/nsites/analysis.py index 24930ba8..52bd3091 100644 --- a/src/bloqade/squin/analysis/nsites/analysis.py +++ b/src/bloqade/squin/analysis/nsites/analysis.py @@ -23,11 +23,11 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): if method is not None: return method(self, frame, stmt) elif stmt.has_trait(HasSites): - has_sites_trait = stmt.get_trait(HasSites) + has_sites_trait = stmt.get_present_trait(HasSites) sites = has_sites_trait.get_sites(stmt) return (NumberSites(sites=sites),) elif stmt.has_trait(FixedSites): - sites_trait = stmt.get_trait(FixedSites) + sites_trait = stmt.get_present_trait(FixedSites) return (NumberSites(sites=sites_trait.data),) else: return (NoSites(),) diff --git a/src/bloqade/visual/animation/base.py b/src/bloqade/visual/animation/base.py index 870b7322..08ec9825 100644 --- a/src/bloqade/visual/animation/base.py +++ b/src/bloqade/visual/animation/base.py @@ -45,10 +45,12 @@ def get_artists(self) -> Tuple[Any]: class GlobalGateArtist(GateArtist): mpl_obj: mpatches.Rectangle - def __init__(self, mpl_ax: Any, xmin, ymin, width, height, color): + def __init__( + self, mpl_ax: Any, xmin: float, ymin: float, width: float, height: float, color + ): super().__init__(mpl_ax) rc = mpatches.Rectangle( - [xmin, ymin], width, height, color=color, alpha=0.6, visible=False + (xmin, ymin), width, height, color=color, alpha=0.6, visible=False ) mpl_ax.add_patch(rc) self.mpl_obj = rc @@ -56,7 +58,7 @@ def __init__(self, mpl_ax: Any, xmin, ymin, width, height, color): def clear_data(self) -> None: self.mpl_obj.set_width(0) self.mpl_obj.set_height(0) - self.mpl_obj.set_xy([0, 0]) + self.mpl_obj.set_xy((0, 0)) def get_artists(self) -> Tuple[Any]: return (self.mpl_obj,) @@ -86,7 +88,7 @@ def __init__( self.width = width self.xmin = xmin rc_btm = mpatches.Rectangle( - [xmin, ymin_keepout], + (xmin, ymin_keepout), width, ymin - ymin_keepout, color=color, @@ -97,13 +99,13 @@ def __init__( self.mpl_obj_keepout_btm = rc_btm rc = mpatches.Rectangle( - [xmin, ymin], width, ymax - ymin, color=color, alpha=0.6, visible=False + (xmin, ymin), width, ymax - ymin, color=color, alpha=0.6, visible=False ) mpl_ax.add_patch(rc) self.mpl_obj = rc rc_top = mpatches.Rectangle( - [xmin, ymax], + (xmin, ymax), width, ymax_keepout - ymax, color=color, @@ -116,31 +118,31 @@ def __init__( def clear_data(self) -> None: self.mpl_obj.set_width(0) self.mpl_obj.set_height(0) - self.mpl_obj.set_xy([0, 0]) + self.mpl_obj.set_xy((0, 0)) self.mpl_obj_keepout_top.set_width(0) self.mpl_obj_keepout_top.set_height(0) - self.mpl_obj_keepout_top.set_xy([0, 0]) + self.mpl_obj_keepout_top.set_xy((0, 0)) self.mpl_obj_keepout_btm.set_width(0) self.mpl_obj_keepout_btm.set_height(0) - self.mpl_obj_keepout_btm.set_xy([0, 0]) + self.mpl_obj_keepout_btm.set_xy((0, 0)) - def get_artists(self) -> Tuple[Any]: + def get_artists(self) -> Tuple[Any, ...]: return (self.mpl_obj, self.mpl_obj_keepout_top, self.mpl_obj_keepout_btm) def update_data(self, ymin, ymax, ymin_keepout, ymax_keepout): self.mpl_obj.set_height(ymax - ymin) self.mpl_obj.set_width(self.width) - self.mpl_obj.set_xy([self.xmin, ymin]) + self.mpl_obj.set_xy((self.xmin, ymin)) self.mpl_obj_keepout_top.set_height(ymax_keepout - ymax) self.mpl_obj_keepout_top.set_width(self.width) - self.mpl_obj_keepout_top.set_xy([self.xmin, ymax]) + self.mpl_obj_keepout_top.set_xy((self.xmin, ymax)) self.mpl_obj_keepout_btm.set_height(ymin - ymin_keepout) self.mpl_obj_keepout_btm.set_width(self.width) - self.mpl_obj_keepout_btm.set_xy([self.xmin, ymin_keepout]) + self.mpl_obj_keepout_btm.set_xy((self.xmin, ymin_keepout)) def set_visible(self, visible: bool): self.mpl_obj.set_visible(visible) @@ -194,11 +196,19 @@ def not_defined(self): ) @property - def width(self): + def width(self) -> float: + if self.xmax is None or self.xmin is None: + raise ValueError( + "Can't return width of FOV as either xmin or xmax are undefined" + ) return self.xmax - self.xmin @property - def height(self): + def height(self) -> float: + if self.ymax is None or self.ymin is None: + raise ValueError( + "Can't return width of FOV as either ymin or ymax are undefined" + ) return self.ymax - self.ymin def to_json(self):