From 57f1102b6fa3e2f736bac29b97fbe52ae7610f59 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 10:07:06 +0200 Subject: [PATCH 01/15] Fix typing in noise --- src/bloqade/noise/native/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/bloqade/noise/native/model.py b/src/bloqade/noise/native/model.py index 1804f00b..90597939 100644 --- a/src/bloqade/noise/native/model.py +++ b/src/bloqade/noise/native/model.py @@ -102,10 +102,9 @@ class MoveNoiseModelABC(abc.ABC): params: MoveNoiseParams = field(default_factory=MoveNoiseParams) """Parameters for calculating move noise.""" - @classmethod @abc.abstractmethod def parallel_cz_errors( - cls, ctrls: List[int], qargs: List[int], rest: List[int] + self, ctrls: List[int], qargs: List[int], rest: List[int] ) -> Dict[Tuple[float, float, float, float], List[int]]: """Takes a set of ctrls and qargs and returns a noise model for all qubits.""" pass From 24bdee95483fcac603bc813f7ad4b2145f4b3267 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 10:14:40 +0200 Subject: [PATCH 02/15] Fix pyright in qasm2/parse --- src/bloqade/qasm2/parse/lowering.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/bloqade/qasm2/parse/lowering.py b/src/bloqade/qasm2/parse/lowering.py index 5b1d9d59..caa67829 100644 --- a/src/bloqade/qasm2/parse/lowering.py +++ b/src/bloqade/qasm2/parse/lowering.py @@ -85,6 +85,10 @@ def lower_literal(self, state: lowering.State[ast.Node], value) -> ir.SSAValue: stmt = expr.ConstInt(value=value) elif isinstance(value, float): stmt = expr.ConstFloat(value=value) + else: + raise lowering.BuildError( + f"Expected value of type float or int, got {type(value)}." + ) state.current_frame.push(stmt) return stmt.result @@ -99,6 +103,8 @@ def visit_MainProgram(self, state: lowering.State[ast.Node], node: ast.MainProgr dialects = ["qasm2.core", "qasm2.uop", "qasm2.expr"] elif isinstance(node.header, ast.Kirin): dialects = node.header.dialects + else: + raise lowering.BuildError(f"Unexpected node header {node.header}") for dialect in dialects: if dialect not in allowed: @@ -295,6 +301,8 @@ def visit_Bit(self, state: lowering.State[ast.Node], node: ast.Bit): stmt = core.QRegGet(reg, addr.result) elif reg.type.is_subseteq(CRegType): stmt = core.CRegGet(reg, addr.result) + else: + raise lowering.BuildError(f"Unexpected register type {reg.type}") return state.current_frame.push(stmt).result def visit_Call(self, state: lowering.State[ast.Node], node: ast.Call): From 4aedcf8d7a9b5bebd936c42723f5df7fec3e4d9f Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 10:15:54 +0200 Subject: [PATCH 03/15] Fix pyright in qasm2/passes --- src/bloqade/qasm2/passes/fold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bloqade/qasm2/passes/fold.py b/src/bloqade/qasm2/passes/fold.py index afb1b880..fdc47326 100644 --- a/src/bloqade/qasm2/passes/fold.py +++ b/src/bloqade/qasm2/passes/fold.py @@ -19,7 +19,7 @@ from kirin.analysis import const from kirin.dialects import scf, ilist from kirin.ir.method import Method -from kirin.rewrite.abc import RewriteResult +from kirin.rewrite.result import RewriteResult from bloqade.qasm2.dialects import expr From 7d0165244f2faa779e72a4e1a753c6b273a20c6e Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 10:28:45 +0200 Subject: [PATCH 04/15] Hardcode map - eww --- src/bloqade/qasm2/rewrite/native_gates.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bloqade/qasm2/rewrite/native_gates.py b/src/bloqade/qasm2/rewrite/native_gates.py index 7d4840c4..4d3599ed 100644 --- a/src/bloqade/qasm2/rewrite/native_gates.py +++ b/src/bloqade/qasm2/rewrite/native_gates.py @@ -70,7 +70,7 @@ def _circuit_diagram_info_(self, args): return "*", "CU" -def around(val): +def around(val) -> float: return float(np.around(val, 14)) @@ -78,7 +78,7 @@ def one_qubit_gate_to_u3_angles(op: cirq.Operation) -> tuple[float, float, float lam, theta, phi = ( # Z angle, Y angle, then Z angle cirq.deconstruct_single_qubit_matrix_into_angles(cirq.unitary(op)) ) - return tuple(map(around, (theta, phi, lam))) + return around(theta), around(phi), around(lam) @dataclass From 5b325e06c155a60a7987edcf8c757f302666ef5b Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 10:38:51 +0200 Subject: [PATCH 05/15] Fix fishy if clause in rewrite_cu3 --- src/bloqade/qasm2/rewrite/native_gates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bloqade/qasm2/rewrite/native_gates.py b/src/bloqade/qasm2/rewrite/native_gates.py index 4d3599ed..6f7bf851 100644 --- a/src/bloqade/qasm2/rewrite/native_gates.py +++ b/src/bloqade/qasm2/rewrite/native_gates.py @@ -279,7 +279,7 @@ def rewrite_cu3(self, node: uop.CU3) -> result.RewriteResult: lam = self._get_const_value(node.lam) phi = self._get_const_value(node.phi) - if not all((theta, phi, lam)): + if theta is None or lam is None or phi is None: return result.RewriteResult() # cirq.ControlledGate(u3(theta, lambda phi)) From 509055cf6f77546301ff01ff1aa2bd6b821f6d08 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 10:45:33 +0200 Subject: [PATCH 06/15] Fix RewriteResult imports in qasm2 --- src/bloqade/qasm2/rewrite/register.py | 3 +- src/bloqade/qasm2/rewrite/uop_to_parallel.py | 41 ++++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/bloqade/qasm2/rewrite/register.py b/src/bloqade/qasm2/rewrite/register.py index 3784d6f0..bddca2c8 100644 --- a/src/bloqade/qasm2/rewrite/register.py +++ b/src/bloqade/qasm2/rewrite/register.py @@ -1,6 +1,7 @@ from kirin import ir from kirin.dialects import py -from kirin.rewrite.abc import RewriteRule, RewriteResult +from kirin.rewrite.abc import RewriteRule +from kirin.rewrite.result import RewriteResult from bloqade.qasm2.dialects import core diff --git a/src/bloqade/qasm2/rewrite/uop_to_parallel.py b/src/bloqade/qasm2/rewrite/uop_to_parallel.py index aa7d5bb7..6a69c970 100644 --- a/src/bloqade/qasm2/rewrite/uop_to_parallel.py +++ b/src/bloqade/qasm2/rewrite/uop_to_parallel.py @@ -3,9 +3,10 @@ from dataclasses import field, dataclass from kirin import ir -from kirin.rewrite import abc as rewrite_abc from kirin.dialects import py, ilist +from kirin.rewrite.abc import RewriteRule from kirin.analysis.const import lattice +from kirin.rewrite.result import RewriteResult from bloqade.analysis import address from bloqade.qasm2.dialects import uop, core, parallel @@ -14,7 +15,7 @@ class MergePolicyABC(abc.ABC): @abc.abstractmethod - def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult: + def __call__(self, node: ir.Statement) -> RewriteResult: pass @classmethod @@ -141,10 +142,10 @@ def from_analysis( group_numbers=group_numbers, ) - def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult: + def __call__(self, node: ir.Statement) -> RewriteResult: if node not in self.group_numbers: - return rewrite_abc.RewriteResult() + return RewriteResult() group_number = self.group_numbers[node] group = self.merge_groups[group_number] @@ -157,9 +158,7 @@ def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult: if self.group_has_merged[group_number]: node.delete() - return rewrite_abc.RewriteResult( - has_done_something=self.group_has_merged[group_number] - ) + return RewriteResult(has_done_something=self.group_has_merged[group_number]) def move_and_collect_qubit_list( self, qargs: List[ir.SSAValue], node: ir.Statement @@ -219,14 +218,14 @@ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]): ctrls.append(stmt.ctrls) qargs.append(stmt.qargs) else: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) ctrls_values = self.move_and_collect_qubit_list(ctrls, node) qargs_values = self.move_and_collect_qubit_list(qargs, node) if ctrls_values is None or qargs_values is None: # give up if we cannot determine the address or cannot move the qubits - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) new_ctrls = ilist.New(values=ctrls_values) new_qargs = ilist.New(values=qargs_values) @@ -238,7 +237,7 @@ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]): node.delete() - return rewrite_abc.RewriteResult(has_done_something=True) + return RewriteResult(has_done_something=True) def rewrite_group_U(self, node: ir.Statement, group: List[ir.Statement]): return self.rewrite_group_u(node, group) @@ -252,13 +251,13 @@ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]): elif isinstance(stmt, parallel.UGate): qargs.append(stmt.qargs) else: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) assert isinstance(node, (uop.UGate, parallel.UGate)) qargs_values = self.move_and_collect_qubit_list(qargs, node) if qargs_values is None: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) new_qargs = ilist.New(values=qargs_values) new_gate = parallel.UGate( @@ -271,7 +270,7 @@ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]): new_gate.insert_before(node) node.delete() - return rewrite_abc.RewriteResult(has_done_something=True) + return RewriteResult(has_done_something=True) def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]): qargs = [] @@ -282,14 +281,14 @@ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]): elif isinstance(stmt, parallel.RZ): qargs.append(stmt.qargs) else: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) assert isinstance(node, (uop.RZ, parallel.RZ)) qargs_values = self.move_and_collect_qubit_list(qargs, node) if qargs_values is None: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) new_qargs = ilist.New(values=qargs_values) new_gate = parallel.RZ( @@ -300,7 +299,7 @@ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]): new_gate.insert_before(node) node.delete() - return rewrite_abc.RewriteResult(has_done_something=True) + return RewriteResult(has_done_something=True) def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]): qargs = [] @@ -310,13 +309,13 @@ def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]): qargs_values = self.move_and_collect_qubit_list(qargs, node) if qargs_values is None: - return rewrite_abc.RewriteResult(has_done_something=False) + return RewriteResult(has_done_something=False) new_node = uop.Barrier(qargs=qargs_values) new_node.insert_before(node) node.delete() - return rewrite_abc.RewriteResult(has_done_something=True) + return RewriteResult(has_done_something=True) class GreedyMixin(MergePolicyABC): @@ -385,11 +384,11 @@ class SimpleOptimalMergePolicy(OptimalMixIn, SimpleMergePolicy): @dataclass -class UOpToParallelRule(rewrite_abc.RewriteRule): +class UOpToParallelRule(RewriteRule): merge_rewriters: Dict[ir.Block | None, MergePolicyABC] - def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult: + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: merge_rewriter = self.merge_rewriters.get( - node.parent_block, lambda _: rewrite_abc.RewriteResult() + node.parent_block, lambda _: RewriteResult() ) return merge_rewriter(node) From e6643e4407821da1c75e0f4f63de3e74048cddf1 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 11:18:23 +0200 Subject: [PATCH 07/15] Explicitly cast to specific addresses in address impls --- src/bloqade/analysis/address/impls.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index a9ae40e8..09dc536b 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -190,7 +190,7 @@ def unwrap( stmt: squin.wire.Unwrap, ): - origin_qubit = frame.get(stmt.qubit) + origin_qubit = frame.get_casted(stmt.qubit, AddressQubit) return (AddressWire(origin_qubit=origin_qubit),) @@ -203,7 +203,10 @@ def apply( ): origin_qubits = tuple( - [frame.get(input_elem).origin_qubit for input_elem in stmt.inputs] + [ + frame.get_casted(input_elem, AddressWire).origin_qubit + for input_elem in stmt.inputs + ] ) new_address_wires = tuple( [AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits] From 57b323cef2e2d05f4cdf667169e4ff871147e14e Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 11:20:59 +0200 Subject: [PATCH 08/15] Explicitly freeze qbraid schema classes --- src/bloqade/qbraid/schema.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/bloqade/qbraid/schema.py b/src/bloqade/qbraid/schema.py index ae103cd9..5578e97a 100644 --- a/src/bloqade/qbraid/schema.py +++ b/src/bloqade/qbraid/schema.py @@ -9,7 +9,7 @@ class Operation(BaseModel, frozen=True, extra="forbid"): op_type: str = Field(init=False) -class CZ(Operation): +class CZ(Operation, frozen=True): """A CZ gate operation. Fields: @@ -22,7 +22,7 @@ class CZ(Operation): participants: Tuple[Union[Tuple[int], Tuple[int, int]], ...] -class GlobalRz(Operation): +class GlobalRz(Operation, frozen=True): """GlobalRz operation. Fields: @@ -34,7 +34,7 @@ class GlobalRz(Operation): phi: float -class GlobalW(Operation): +class GlobalW(Operation, frozen=True): """GlobalW operation. Fields: @@ -48,7 +48,7 @@ class GlobalW(Operation): phi: float -class LocalRz(Operation): +class LocalRz(Operation, frozen=True): """LocalRz operation. Fields: @@ -63,7 +63,7 @@ class LocalRz(Operation): phi: float -class LocalW(Operation): +class LocalW(Operation, frozen=True): """LocalW operation. Fields: @@ -80,7 +80,7 @@ class LocalW(Operation): phi: float -class Measurement(Operation): +class Measurement(Operation, frozen=True): """Measurement operation. Fields: @@ -106,7 +106,7 @@ class ErrorModel(BaseModel, frozen=True, extra="forbid"): error_model_type: str = Field(init=False) -class PauliErrorModel(ErrorModel): +class PauliErrorModel(ErrorModel, frozen=True): """Pauli error model. Fields: @@ -131,7 +131,7 @@ class ErrorOperation(BaseModel, Generic[ErrorModelType], frozen=True, extra="for survival_prob: Tuple[float, ...] -class CZError(ErrorOperation[ErrorModelType]): +class CZError(ErrorOperation[ErrorModelType], frozen=True): """CZError operation. Fields: @@ -149,7 +149,7 @@ class CZError(ErrorOperation[ErrorModelType]): single_error: ErrorModelType -class SingleQubitError(ErrorOperation[ErrorModelType]): +class SingleQubitError(ErrorOperation[ErrorModelType], frozen=True): """SingleQubitError operation. Fields: From f446c5ab740607b5faa28f11985c79eab70b4ea8 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 11:26:26 +0200 Subject: [PATCH 09/15] Add get_presemt_trait method to ensure type --- src/bloqade/squin/analysis/nsites/analysis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bloqade/squin/analysis/nsites/analysis.py b/src/bloqade/squin/analysis/nsites/analysis.py index 24930ba8..52bd3091 100644 --- a/src/bloqade/squin/analysis/nsites/analysis.py +++ b/src/bloqade/squin/analysis/nsites/analysis.py @@ -23,11 +23,11 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): if method is not None: return method(self, frame, stmt) elif stmt.has_trait(HasSites): - has_sites_trait = stmt.get_trait(HasSites) + has_sites_trait = stmt.get_present_trait(HasSites) sites = has_sites_trait.get_sites(stmt) return (NumberSites(sites=sites),) elif stmt.has_trait(FixedSites): - sites_trait = stmt.get_trait(FixedSites) + sites_trait = stmt.get_present_trait(FixedSites) return (NumberSites(sites=sites_trait.data),) else: return (NoSites(),) From ea5e94ea21011fad0e5ea1f6acfa69b62d44d34f Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 11:41:09 +0200 Subject: [PATCH 10/15] Assert some types in animation base --- src/bloqade/visual/animation/base.py | 40 +++++++++++++++++----------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/src/bloqade/visual/animation/base.py b/src/bloqade/visual/animation/base.py index 870b7322..08ec9825 100644 --- a/src/bloqade/visual/animation/base.py +++ b/src/bloqade/visual/animation/base.py @@ -45,10 +45,12 @@ def get_artists(self) -> Tuple[Any]: class GlobalGateArtist(GateArtist): mpl_obj: mpatches.Rectangle - def __init__(self, mpl_ax: Any, xmin, ymin, width, height, color): + def __init__( + self, mpl_ax: Any, xmin: float, ymin: float, width: float, height: float, color + ): super().__init__(mpl_ax) rc = mpatches.Rectangle( - [xmin, ymin], width, height, color=color, alpha=0.6, visible=False + (xmin, ymin), width, height, color=color, alpha=0.6, visible=False ) mpl_ax.add_patch(rc) self.mpl_obj = rc @@ -56,7 +58,7 @@ def __init__(self, mpl_ax: Any, xmin, ymin, width, height, color): def clear_data(self) -> None: self.mpl_obj.set_width(0) self.mpl_obj.set_height(0) - self.mpl_obj.set_xy([0, 0]) + self.mpl_obj.set_xy((0, 0)) def get_artists(self) -> Tuple[Any]: return (self.mpl_obj,) @@ -86,7 +88,7 @@ def __init__( self.width = width self.xmin = xmin rc_btm = mpatches.Rectangle( - [xmin, ymin_keepout], + (xmin, ymin_keepout), width, ymin - ymin_keepout, color=color, @@ -97,13 +99,13 @@ def __init__( self.mpl_obj_keepout_btm = rc_btm rc = mpatches.Rectangle( - [xmin, ymin], width, ymax - ymin, color=color, alpha=0.6, visible=False + (xmin, ymin), width, ymax - ymin, color=color, alpha=0.6, visible=False ) mpl_ax.add_patch(rc) self.mpl_obj = rc rc_top = mpatches.Rectangle( - [xmin, ymax], + (xmin, ymax), width, ymax_keepout - ymax, color=color, @@ -116,31 +118,31 @@ def __init__( def clear_data(self) -> None: self.mpl_obj.set_width(0) self.mpl_obj.set_height(0) - self.mpl_obj.set_xy([0, 0]) + self.mpl_obj.set_xy((0, 0)) self.mpl_obj_keepout_top.set_width(0) self.mpl_obj_keepout_top.set_height(0) - self.mpl_obj_keepout_top.set_xy([0, 0]) + self.mpl_obj_keepout_top.set_xy((0, 0)) self.mpl_obj_keepout_btm.set_width(0) self.mpl_obj_keepout_btm.set_height(0) - self.mpl_obj_keepout_btm.set_xy([0, 0]) + self.mpl_obj_keepout_btm.set_xy((0, 0)) - def get_artists(self) -> Tuple[Any]: + def get_artists(self) -> Tuple[Any, ...]: return (self.mpl_obj, self.mpl_obj_keepout_top, self.mpl_obj_keepout_btm) def update_data(self, ymin, ymax, ymin_keepout, ymax_keepout): self.mpl_obj.set_height(ymax - ymin) self.mpl_obj.set_width(self.width) - self.mpl_obj.set_xy([self.xmin, ymin]) + self.mpl_obj.set_xy((self.xmin, ymin)) self.mpl_obj_keepout_top.set_height(ymax_keepout - ymax) self.mpl_obj_keepout_top.set_width(self.width) - self.mpl_obj_keepout_top.set_xy([self.xmin, ymax]) + self.mpl_obj_keepout_top.set_xy((self.xmin, ymax)) self.mpl_obj_keepout_btm.set_height(ymin - ymin_keepout) self.mpl_obj_keepout_btm.set_width(self.width) - self.mpl_obj_keepout_btm.set_xy([self.xmin, ymin_keepout]) + self.mpl_obj_keepout_btm.set_xy((self.xmin, ymin_keepout)) def set_visible(self, visible: bool): self.mpl_obj.set_visible(visible) @@ -194,11 +196,19 @@ def not_defined(self): ) @property - def width(self): + def width(self) -> float: + if self.xmax is None or self.xmin is None: + raise ValueError( + "Can't return width of FOV as either xmin or xmax are undefined" + ) return self.xmax - self.xmin @property - def height(self): + def height(self) -> float: + if self.ymax is None or self.ymin is None: + raise ValueError( + "Can't return width of FOV as either ymin or ymax are undefined" + ) return self.ymax - self.ymin def to_json(self): From 6c0c6e642f35e896fa3f5a8fc8bf9ce179a49915 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 12:49:32 +0200 Subject: [PATCH 11/15] Explicitly cast addresses to QubitAddress in heuristic_noise --- src/bloqade/qasm2/rewrite/heuristic_noise.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/bloqade/qasm2/rewrite/heuristic_noise.py b/src/bloqade/qasm2/rewrite/heuristic_noise.py index 6180ccef..27167cad 100644 --- a/src/bloqade/qasm2/rewrite/heuristic_noise.py +++ b/src/bloqade/qasm2/rewrite/heuristic_noise.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, cast from dataclasses import field, dataclass from kirin import ir @@ -226,8 +226,12 @@ def rewrite_parallel_cz_gate(self, node: parallel.CZ): and isinstance(qargs, address.AddressTuple) and all(isinstance(addr, address.AddressQubit) for addr in qargs.data) ): - ctrl_qubits = list(map(lambda addr: addr.data, ctrls.data)) - qarg_qubits = list(map(lambda addr: addr.data, qargs.data)) + ctrl_qubits = list( + map(lambda addr: cast(address.AddressQubit, addr).data, ctrls.data) + ) + qarg_qubits = list( + map(lambda addr: cast(address.AddressQubit, addr).data, qargs.data) + ) rest = sorted( set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits) ) From 653bc6f84e455395b67d2177d1eeef4a7d06c9d3 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 12:51:04 +0200 Subject: [PATCH 12/15] Change OperatorType from TypeVar to union --- src/bloqade/qbraid/schema.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/bloqade/qbraid/schema.py b/src/bloqade/qbraid/schema.py index 5578e97a..54eed1c2 100644 --- a/src/bloqade/qbraid/schema.py +++ b/src/bloqade/qbraid/schema.py @@ -95,9 +95,7 @@ class Measurement(Operation, frozen=True): participants: Tuple[int, ...] -OperationType = TypeVar( - "OperationType", bound=Union[CZ, GlobalRz, GlobalW, LocalRz, LocalW, Measurement] -) +OperationType = CZ | GlobalRz | GlobalW | LocalRz | LocalW | Measurement class ErrorModel(BaseModel, frozen=True, extra="forbid"): From 07c6380e688ffa64432e5a8c2d8114324f41ce22 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 16 Apr 2025 16:43:17 +0200 Subject: [PATCH 13/15] Update src/bloqade/analysis/address/impls.py Co-authored-by: Phillip Weinberg --- src/bloqade/analysis/address/impls.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index 09dc536b..b4035be0 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -190,9 +190,13 @@ def unwrap( stmt: squin.wire.Unwrap, ): - origin_qubit = frame.get_casted(stmt.qubit, AddressQubit) + origin_qubit = frame.get(stmt.qubit) + + if isintance(origin_qubit, AddressQubit): - return (AddressWire(origin_qubit=origin_qubit),) + return (AddressWire(origin_qubit=origin_qubit),) + else: + return (Address.top(), ) @interp.impl(squin.wire.Apply) def apply( From d8a0534b571f659e3a0294ea64b18f30c858bc89 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 18 Apr 2025 08:46:45 +0200 Subject: [PATCH 14/15] Fix typo --- src/bloqade/analysis/address/impls.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index b4035be0..da376628 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -191,12 +191,11 @@ def unwrap( ): origin_qubit = frame.get(stmt.qubit) - - if isintance(origin_qubit, AddressQubit): + if isinstance(origin_qubit, AddressQubit): return (AddressWire(origin_qubit=origin_qubit),) else: - return (Address.top(), ) + return (Address.top(),) @interp.impl(squin.wire.Apply) def apply( From 29da51eb093790e597078eb199146a12e815ed47 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 18 Apr 2025 08:49:09 +0200 Subject: [PATCH 15/15] Address Phil's last comment --- src/bloqade/analysis/address/impls.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index da376628..ee5d8414 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -204,17 +204,7 @@ def apply( frame: ForwardFrame[Address], stmt: squin.wire.Apply, ): - - origin_qubits = tuple( - [ - frame.get_casted(input_elem, AddressWire).origin_qubit - for input_elem in stmt.inputs - ] - ) - new_address_wires = tuple( - [AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits] - ) - return new_address_wires + return frame.get_values(stmt.inputs) @squin.qubit.dialect.register(key="qubit.address")