Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions src/bloqade/analysis/address/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,10 @@

origin_qubit = frame.get(stmt.qubit)

return (AddressWire(origin_qubit=origin_qubit),)
if isinstance(origin_qubit, AddressQubit):
return (AddressWire(origin_qubit=origin_qubit),)
else:
return (Address.top(),)

Check warning on line 198 in src/bloqade/analysis/address/impls.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/analysis/address/impls.py#L198

Added line #L198 was not covered by tests

@interp.impl(squin.wire.Apply)
def apply(
Expand All @@ -201,14 +204,7 @@
frame: ForwardFrame[Address],
stmt: squin.wire.Apply,
):

origin_qubits = tuple(
[frame.get(input_elem).origin_qubit for input_elem in stmt.inputs]
)
new_address_wires = tuple(
[AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits]
)
return new_address_wires
return frame.get_values(stmt.inputs)


@squin.qubit.dialect.register(key="qubit.address")
Expand Down
3 changes: 1 addition & 2 deletions src/bloqade/noise/native/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ class MoveNoiseModelABC(abc.ABC):
params: MoveNoiseParams = field(default_factory=MoveNoiseParams)
"""Parameters for calculating move noise."""

@classmethod
@abc.abstractmethod
def parallel_cz_errors(
cls, ctrls: List[int], qargs: List[int], rest: List[int]
self, ctrls: List[int], qargs: List[int], rest: List[int]
) -> Dict[Tuple[float, float, float, float], List[int]]:
"""Takes a set of ctrls and qargs and returns a noise model for all qubits."""
pass
Expand Down
9 changes: 7 additions & 2 deletions src/bloqade/qasm2/parse/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,9 @@
elif isinstance(value, float):
stmt = expr.ConstFloat(value=value)
else:
raise lowering.BuildError(f"Unsupported literal type {type(value)}")

raise lowering.BuildError(

Check warning on line 203 in src/bloqade/qasm2/parse/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/parse/lowering.py#L203

Added line #L203 was not covered by tests
f"Expected value of type float or int, got {type(value)}."
)
state.current_frame.push(stmt)
return stmt.result

Expand All @@ -216,6 +217,8 @@
dialects = ["qasm2.core", "qasm2.uop", "qasm2.expr"]
elif isinstance(node.header, ast.Kirin):
dialects = node.header.dialects
else:
raise lowering.BuildError(f"Unexpected node header {node.header}")

Check warning on line 221 in src/bloqade/qasm2/parse/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/parse/lowering.py#L221

Added line #L221 was not covered by tests

for dialect in dialects:
if dialect not in allowed:
Expand Down Expand Up @@ -412,6 +415,8 @@
stmt = core.QRegGet(reg, addr.result)
elif reg.type.is_subseteq(CRegType):
stmt = core.CRegGet(reg, addr.result)
else:
raise lowering.BuildError(f"Unexpected register type {reg.type}")

Check warning on line 419 in src/bloqade/qasm2/parse/lowering.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/parse/lowering.py#L419

Added line #L419 was not covered by tests
return state.current_frame.push(stmt).result

def visit_Call(self, state: lowering.State[ast.Node], node: ast.Call):
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/passes/fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from kirin.analysis import const
from kirin.dialects import scf, ilist
from kirin.ir.method import Method
from kirin.rewrite.abc import RewriteResult
from kirin.rewrite.result import RewriteResult

from bloqade.qasm2.dialects import expr

Expand Down
10 changes: 7 additions & 3 deletions src/bloqade/qasm2/rewrite/heuristic_noise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, cast
from dataclasses import field, dataclass

from kirin import ir
Expand Down Expand Up @@ -226,8 +226,12 @@ def rewrite_parallel_cz_gate(self, node: parallel.CZ):
and isinstance(qargs, address.AddressTuple)
and all(isinstance(addr, address.AddressQubit) for addr in qargs.data)
):
ctrl_qubits = list(map(lambda addr: addr.data, ctrls.data))
qarg_qubits = list(map(lambda addr: addr.data, qargs.data))
ctrl_qubits = list(
map(lambda addr: cast(address.AddressQubit, addr).data, ctrls.data)
)
qarg_qubits = list(
map(lambda addr: cast(address.AddressQubit, addr).data, qargs.data)
)
rest = sorted(
set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits)
)
Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/qasm2/rewrite/native_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ def _circuit_diagram_info_(self, args):
return "*", "CU"


def around(val):
def around(val) -> float:
return float(np.around(val, 14))


def one_qubit_gate_to_u3_angles(op: cirq.Operation) -> tuple[float, float, float]:
lam, theta, phi = ( # Z angle, Y angle, then Z angle
cirq.deconstruct_single_qubit_matrix_into_angles(cirq.unitary(op))
)
return tuple(map(around, (theta, phi, lam)))
return around(theta), around(phi), around(lam)


@dataclass
Expand Down
3 changes: 2 additions & 1 deletion src/bloqade/qasm2/rewrite/register.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from kirin import ir
from kirin.dialects import py
from kirin.rewrite.abc import RewriteRule, RewriteResult
from kirin.rewrite.abc import RewriteRule
from kirin.rewrite.result import RewriteResult

from bloqade.qasm2.dialects import core

Expand Down
41 changes: 20 additions & 21 deletions src/bloqade/qasm2/rewrite/uop_to_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from dataclasses import field, dataclass

from kirin import ir
from kirin.rewrite import abc as rewrite_abc
from kirin.dialects import py, ilist
from kirin.rewrite.abc import RewriteRule
from kirin.analysis.const import lattice
from kirin.rewrite.result import RewriteResult

from bloqade.analysis import address
from bloqade.qasm2.dialects import uop, core, parallel
Expand All @@ -14,7 +15,7 @@

class MergePolicyABC(abc.ABC):
@abc.abstractmethod
def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
def __call__(self, node: ir.Statement) -> RewriteResult:
pass

@classmethod
Expand Down Expand Up @@ -141,10 +142,10 @@
group_numbers=group_numbers,
)

def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
def __call__(self, node: ir.Statement) -> RewriteResult:

if node not in self.group_numbers:
return rewrite_abc.RewriteResult()
return RewriteResult()

group_number = self.group_numbers[node]
group = self.merge_groups[group_number]
Expand All @@ -157,9 +158,7 @@
if self.group_has_merged[group_number]:
node.delete()

return rewrite_abc.RewriteResult(
has_done_something=self.group_has_merged[group_number]
)
return RewriteResult(has_done_something=self.group_has_merged[group_number])

def move_and_collect_qubit_list(
self, qargs: List[ir.SSAValue], node: ir.Statement
Expand Down Expand Up @@ -219,14 +218,14 @@
ctrls.append(stmt.ctrls)
qargs.append(stmt.qargs)
else:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

Check warning on line 221 in src/bloqade/qasm2/rewrite/uop_to_parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/rewrite/uop_to_parallel.py#L221

Added line #L221 was not covered by tests

ctrls_values = self.move_and_collect_qubit_list(ctrls, node)
qargs_values = self.move_and_collect_qubit_list(qargs, node)

if ctrls_values is None or qargs_values is None:
# give up if we cannot determine the address or cannot move the qubits
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

Check warning on line 228 in src/bloqade/qasm2/rewrite/uop_to_parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/rewrite/uop_to_parallel.py#L228

Added line #L228 was not covered by tests

new_ctrls = ilist.New(values=ctrls_values)
new_qargs = ilist.New(values=qargs_values)
Expand All @@ -238,7 +237,7 @@

node.delete()

return rewrite_abc.RewriteResult(has_done_something=True)
return RewriteResult(has_done_something=True)

Check warning on line 240 in src/bloqade/qasm2/rewrite/uop_to_parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/rewrite/uop_to_parallel.py#L240

Added line #L240 was not covered by tests

def rewrite_group_U(self, node: ir.Statement, group: List[ir.Statement]):
return self.rewrite_group_u(node, group)
Expand All @@ -252,13 +251,13 @@
elif isinstance(stmt, parallel.UGate):
qargs.append(stmt.qargs)
else:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

Check warning on line 254 in src/bloqade/qasm2/rewrite/uop_to_parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/rewrite/uop_to_parallel.py#L254

Added line #L254 was not covered by tests

assert isinstance(node, (uop.UGate, parallel.UGate))
qargs_values = self.move_and_collect_qubit_list(qargs, node)

if qargs_values is None:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

Check warning on line 260 in src/bloqade/qasm2/rewrite/uop_to_parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/rewrite/uop_to_parallel.py#L260

Added line #L260 was not covered by tests

new_qargs = ilist.New(values=qargs_values)
new_gate = parallel.UGate(
Expand All @@ -271,7 +270,7 @@
new_gate.insert_before(node)
node.delete()

return rewrite_abc.RewriteResult(has_done_something=True)
return RewriteResult(has_done_something=True)

def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
qargs = []
Expand All @@ -282,14 +281,14 @@
elif isinstance(stmt, parallel.RZ):
qargs.append(stmt.qargs)
else:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

Check warning on line 284 in src/bloqade/qasm2/rewrite/uop_to_parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/rewrite/uop_to_parallel.py#L284

Added line #L284 was not covered by tests

assert isinstance(node, (uop.RZ, parallel.RZ))

qargs_values = self.move_and_collect_qubit_list(qargs, node)

if qargs_values is None:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

Check warning on line 291 in src/bloqade/qasm2/rewrite/uop_to_parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/rewrite/uop_to_parallel.py#L291

Added line #L291 was not covered by tests

new_qargs = ilist.New(values=qargs_values)
new_gate = parallel.RZ(
Expand All @@ -300,7 +299,7 @@
new_gate.insert_before(node)
node.delete()

return rewrite_abc.RewriteResult(has_done_something=True)
return RewriteResult(has_done_something=True)

def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]):
qargs = []
Expand All @@ -310,13 +309,13 @@
qargs_values = self.move_and_collect_qubit_list(qargs, node)

if qargs_values is None:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

Check warning on line 312 in src/bloqade/qasm2/rewrite/uop_to_parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/rewrite/uop_to_parallel.py#L312

Added line #L312 was not covered by tests

new_node = uop.Barrier(qargs=qargs_values)
new_node.insert_before(node)
node.delete()

return rewrite_abc.RewriteResult(has_done_something=True)
return RewriteResult(has_done_something=True)

Check warning on line 318 in src/bloqade/qasm2/rewrite/uop_to_parallel.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/qasm2/rewrite/uop_to_parallel.py#L318

Added line #L318 was not covered by tests


class GreedyMixin(MergePolicyABC):
Expand Down Expand Up @@ -385,11 +384,11 @@


@dataclass
class UOpToParallelRule(rewrite_abc.RewriteRule):
class UOpToParallelRule(RewriteRule):
merge_rewriters: Dict[ir.Block | None, MergePolicyABC]

def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
merge_rewriter = self.merge_rewriters.get(
node.parent_block, lambda _: rewrite_abc.RewriteResult()
node.parent_block, lambda _: RewriteResult()
)
return merge_rewriter(node)
22 changes: 10 additions & 12 deletions src/bloqade/qbraid/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Operation(BaseModel, frozen=True, extra="forbid"):
op_type: str = Field(init=False)


class CZ(Operation):
class CZ(Operation, frozen=True):
"""A CZ gate operation.

Fields:
Expand All @@ -22,7 +22,7 @@ class CZ(Operation):
participants: Tuple[Union[Tuple[int], Tuple[int, int]], ...]


class GlobalRz(Operation):
class GlobalRz(Operation, frozen=True):
"""GlobalRz operation.

Fields:
Expand All @@ -34,7 +34,7 @@ class GlobalRz(Operation):
phi: float


class GlobalW(Operation):
class GlobalW(Operation, frozen=True):
"""GlobalW operation.

Fields:
Expand All @@ -48,7 +48,7 @@ class GlobalW(Operation):
phi: float


class LocalRz(Operation):
class LocalRz(Operation, frozen=True):
"""LocalRz operation.

Fields:
Expand All @@ -63,7 +63,7 @@ class LocalRz(Operation):
phi: float


class LocalW(Operation):
class LocalW(Operation, frozen=True):
"""LocalW operation.

Fields:
Expand All @@ -80,7 +80,7 @@ class LocalW(Operation):
phi: float


class Measurement(Operation):
class Measurement(Operation, frozen=True):
"""Measurement operation.

Fields:
Expand All @@ -95,9 +95,7 @@ class Measurement(Operation):
participants: Tuple[int, ...]


OperationType = TypeVar(
"OperationType", bound=Union[CZ, GlobalRz, GlobalW, LocalRz, LocalW, Measurement]
)
OperationType = CZ | GlobalRz | GlobalW | LocalRz | LocalW | Measurement


class ErrorModel(BaseModel, frozen=True, extra="forbid"):
Expand All @@ -106,7 +104,7 @@ class ErrorModel(BaseModel, frozen=True, extra="forbid"):
error_model_type: str = Field(init=False)


class PauliErrorModel(ErrorModel):
class PauliErrorModel(ErrorModel, frozen=True):
"""Pauli error model.

Fields:
Expand All @@ -131,7 +129,7 @@ class ErrorOperation(BaseModel, Generic[ErrorModelType], frozen=True, extra="for
survival_prob: Tuple[float, ...]


class CZError(ErrorOperation[ErrorModelType]):
class CZError(ErrorOperation[ErrorModelType], frozen=True):
"""CZError operation.

Fields:
Expand All @@ -149,7 +147,7 @@ class CZError(ErrorOperation[ErrorModelType]):
single_error: ErrorModelType


class SingleQubitError(ErrorOperation[ErrorModelType]):
class SingleQubitError(ErrorOperation[ErrorModelType], frozen=True):
"""SingleQubitError operation.

Fields:
Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/squin/analysis/nsites/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
if method is not None:
return method(self, frame, stmt)
elif stmt.has_trait(HasSites):
has_sites_trait = stmt.get_trait(HasSites)
has_sites_trait = stmt.get_present_trait(HasSites)

Check warning on line 26 in src/bloqade/squin/analysis/nsites/analysis.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/analysis/nsites/analysis.py#L26

Added line #L26 was not covered by tests
sites = has_sites_trait.get_sites(stmt)
return (NumberSites(sites=sites),)
elif stmt.has_trait(FixedSites):
sites_trait = stmt.get_trait(FixedSites)
sites_trait = stmt.get_present_trait(FixedSites)
return (NumberSites(sites=sites_trait.data),)
else:
return (NoSites(),)
Expand Down
Loading