Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/bloqade/analysis/address/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def unwrap(
stmt: squin.wire.Unwrap,
):

origin_qubit = frame.get(stmt.qubit)
origin_qubit = frame.get_casted(stmt.qubit, AddressQubit)

return (AddressWire(origin_qubit=origin_qubit),)

Expand All @@ -203,7 +203,10 @@ def apply(
):

origin_qubits = tuple(
[frame.get(input_elem).origin_qubit for input_elem in stmt.inputs]
[
frame.get_casted(input_elem, AddressWire).origin_qubit
for input_elem in stmt.inputs
]
)
new_address_wires = tuple(
[AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also wrong you should just return all the values from the frame because there is no dependency on what the value is.

return frame.get_values(stmt.inputs)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@johnzl-777 how about this one?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I talked to @johnzl-777 in person, he is OK with this change.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed it!

Expand Down
3 changes: 1 addition & 2 deletions src/bloqade/noise/native/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ class MoveNoiseModelABC(abc.ABC):
params: MoveNoiseParams = field(default_factory=MoveNoiseParams)
"""Parameters for calculating move noise."""

@classmethod
@abc.abstractmethod
def parallel_cz_errors(
cls, ctrls: List[int], qargs: List[int], rest: List[int]
self, ctrls: List[int], qargs: List[int], rest: List[int]
) -> Dict[Tuple[float, float, float, float], List[int]]:
"""Takes a set of ctrls and qargs and returns a noise model for all qubits."""
pass
Expand Down
8 changes: 8 additions & 0 deletions src/bloqade/qasm2/parse/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def lower_literal(self, state: lowering.State[ast.Node], value) -> ir.SSAValue:
stmt = expr.ConstInt(value=value)
elif isinstance(value, float):
stmt = expr.ConstFloat(value=value)
else:
raise lowering.BuildError(
f"Expected value of type float or int, got {type(value)}."
)
state.current_frame.push(stmt)
return stmt.result

Expand All @@ -99,6 +103,8 @@ def visit_MainProgram(self, state: lowering.State[ast.Node], node: ast.MainProgr
dialects = ["qasm2.core", "qasm2.uop", "qasm2.expr"]
elif isinstance(node.header, ast.Kirin):
dialects = node.header.dialects
else:
raise lowering.BuildError(f"Unexpected node header {node.header}")

for dialect in dialects:
if dialect not in allowed:
Expand Down Expand Up @@ -295,6 +301,8 @@ def visit_Bit(self, state: lowering.State[ast.Node], node: ast.Bit):
stmt = core.QRegGet(reg, addr.result)
elif reg.type.is_subseteq(CRegType):
stmt = core.CRegGet(reg, addr.result)
else:
raise lowering.BuildError(f"Unexpected register type {reg.type}")
return state.current_frame.push(stmt).result

def visit_Call(self, state: lowering.State[ast.Node], node: ast.Call):
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/passes/fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from kirin.analysis import const
from kirin.dialects import scf, ilist
from kirin.ir.method import Method
from kirin.rewrite.abc import RewriteResult
from kirin.rewrite.result import RewriteResult

from bloqade.qasm2.dialects import expr

Expand Down
10 changes: 7 additions & 3 deletions src/bloqade/qasm2/rewrite/heuristic_noise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, cast
from dataclasses import field, dataclass

from kirin import ir
Expand Down Expand Up @@ -226,8 +226,12 @@ def rewrite_parallel_cz_gate(self, node: parallel.CZ):
and isinstance(qargs, address.AddressTuple)
and all(isinstance(addr, address.AddressQubit) for addr in qargs.data)
):
ctrl_qubits = list(map(lambda addr: addr.data, ctrls.data))
qarg_qubits = list(map(lambda addr: addr.data, qargs.data))
ctrl_qubits = list(
map(lambda addr: cast(address.AddressQubit, addr).data, ctrls.data)
)
qarg_qubits = list(
map(lambda addr: cast(address.AddressQubit, addr).data, qargs.data)
)
rest = sorted(
set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits)
)
Expand Down
6 changes: 3 additions & 3 deletions src/bloqade/qasm2/rewrite/native_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ def _circuit_diagram_info_(self, args):
return "*", "CU"


def around(val):
def around(val) -> float:
return float(np.around(val, 14))


def one_qubit_gate_to_u3_angles(op: cirq.Operation) -> tuple[float, float, float]:
lam, theta, phi = ( # Z angle, Y angle, then Z angle
cirq.deconstruct_single_qubit_matrix_into_angles(cirq.unitary(op))
)
return tuple(map(around, (theta, phi, lam)))
return around(theta), around(phi), around(lam)


@dataclass
Expand Down Expand Up @@ -279,7 +279,7 @@ def rewrite_cu3(self, node: uop.CU3) -> result.RewriteResult:
lam = self._get_const_value(node.lam)
phi = self._get_const_value(node.phi)

if not all((theta, phi, lam)):
if theta is None or lam is None or phi is None:
return result.RewriteResult()

# cirq.ControlledGate(u3(theta, lambda phi))
Expand Down
3 changes: 2 additions & 1 deletion src/bloqade/qasm2/rewrite/register.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from kirin import ir
from kirin.dialects import py
from kirin.rewrite.abc import RewriteRule, RewriteResult
from kirin.rewrite.abc import RewriteRule
from kirin.rewrite.result import RewriteResult

from bloqade.qasm2.dialects import core

Expand Down
41 changes: 20 additions & 21 deletions src/bloqade/qasm2/rewrite/uop_to_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from dataclasses import field, dataclass

from kirin import ir
from kirin.rewrite import abc as rewrite_abc
from kirin.dialects import py, ilist
from kirin.rewrite.abc import RewriteRule
from kirin.analysis.const import lattice
from kirin.rewrite.result import RewriteResult

from bloqade.analysis import address
from bloqade.qasm2.dialects import uop, core, parallel
Expand All @@ -14,7 +15,7 @@

class MergePolicyABC(abc.ABC):
@abc.abstractmethod
def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
def __call__(self, node: ir.Statement) -> RewriteResult:
pass

@classmethod
Expand Down Expand Up @@ -141,10 +142,10 @@ def from_analysis(
group_numbers=group_numbers,
)

def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
def __call__(self, node: ir.Statement) -> RewriteResult:

if node not in self.group_numbers:
return rewrite_abc.RewriteResult()
return RewriteResult()

group_number = self.group_numbers[node]
group = self.merge_groups[group_number]
Expand All @@ -157,9 +158,7 @@ def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
if self.group_has_merged[group_number]:
node.delete()

return rewrite_abc.RewriteResult(
has_done_something=self.group_has_merged[group_number]
)
return RewriteResult(has_done_something=self.group_has_merged[group_number])

def move_and_collect_qubit_list(
self, qargs: List[ir.SSAValue], node: ir.Statement
Expand Down Expand Up @@ -219,14 +218,14 @@ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]):
ctrls.append(stmt.ctrls)
qargs.append(stmt.qargs)
else:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

ctrls_values = self.move_and_collect_qubit_list(ctrls, node)
qargs_values = self.move_and_collect_qubit_list(qargs, node)

if ctrls_values is None or qargs_values is None:
# give up if we cannot determine the address or cannot move the qubits
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

new_ctrls = ilist.New(values=ctrls_values)
new_qargs = ilist.New(values=qargs_values)
Expand All @@ -238,7 +237,7 @@ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]):

node.delete()

return rewrite_abc.RewriteResult(has_done_something=True)
return RewriteResult(has_done_something=True)

def rewrite_group_U(self, node: ir.Statement, group: List[ir.Statement]):
return self.rewrite_group_u(node, group)
Expand All @@ -252,13 +251,13 @@ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]):
elif isinstance(stmt, parallel.UGate):
qargs.append(stmt.qargs)
else:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

assert isinstance(node, (uop.UGate, parallel.UGate))
qargs_values = self.move_and_collect_qubit_list(qargs, node)

if qargs_values is None:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

new_qargs = ilist.New(values=qargs_values)
new_gate = parallel.UGate(
Expand All @@ -271,7 +270,7 @@ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]):
new_gate.insert_before(node)
node.delete()

return rewrite_abc.RewriteResult(has_done_something=True)
return RewriteResult(has_done_something=True)

def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
qargs = []
Expand All @@ -282,14 +281,14 @@ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
elif isinstance(stmt, parallel.RZ):
qargs.append(stmt.qargs)
else:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

assert isinstance(node, (uop.RZ, parallel.RZ))

qargs_values = self.move_and_collect_qubit_list(qargs, node)

if qargs_values is None:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

new_qargs = ilist.New(values=qargs_values)
new_gate = parallel.RZ(
Expand All @@ -300,7 +299,7 @@ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
new_gate.insert_before(node)
node.delete()

return rewrite_abc.RewriteResult(has_done_something=True)
return RewriteResult(has_done_something=True)

def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]):
qargs = []
Expand All @@ -310,13 +309,13 @@ def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]):
qargs_values = self.move_and_collect_qubit_list(qargs, node)

if qargs_values is None:
return rewrite_abc.RewriteResult(has_done_something=False)
return RewriteResult(has_done_something=False)

new_node = uop.Barrier(qargs=qargs_values)
new_node.insert_before(node)
node.delete()

return rewrite_abc.RewriteResult(has_done_something=True)
return RewriteResult(has_done_something=True)


class GreedyMixin(MergePolicyABC):
Expand Down Expand Up @@ -385,11 +384,11 @@ class SimpleOptimalMergePolicy(OptimalMixIn, SimpleMergePolicy):


@dataclass
class UOpToParallelRule(rewrite_abc.RewriteRule):
class UOpToParallelRule(RewriteRule):
merge_rewriters: Dict[ir.Block | None, MergePolicyABC]

def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
merge_rewriter = self.merge_rewriters.get(
node.parent_block, lambda _: rewrite_abc.RewriteResult()
node.parent_block, lambda _: RewriteResult()
)
return merge_rewriter(node)
22 changes: 10 additions & 12 deletions src/bloqade/qbraid/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Operation(BaseModel, frozen=True, extra="forbid"):
op_type: str = Field(init=False)


class CZ(Operation):
class CZ(Operation, frozen=True):
"""A CZ gate operation.

Fields:
Expand All @@ -22,7 +22,7 @@ class CZ(Operation):
participants: Tuple[Union[Tuple[int], Tuple[int, int]], ...]


class GlobalRz(Operation):
class GlobalRz(Operation, frozen=True):
"""GlobalRz operation.

Fields:
Expand All @@ -34,7 +34,7 @@ class GlobalRz(Operation):
phi: float


class GlobalW(Operation):
class GlobalW(Operation, frozen=True):
"""GlobalW operation.

Fields:
Expand All @@ -48,7 +48,7 @@ class GlobalW(Operation):
phi: float


class LocalRz(Operation):
class LocalRz(Operation, frozen=True):
"""LocalRz operation.

Fields:
Expand All @@ -63,7 +63,7 @@ class LocalRz(Operation):
phi: float


class LocalW(Operation):
class LocalW(Operation, frozen=True):
"""LocalW operation.

Fields:
Expand All @@ -80,7 +80,7 @@ class LocalW(Operation):
phi: float


class Measurement(Operation):
class Measurement(Operation, frozen=True):
"""Measurement operation.

Fields:
Expand All @@ -95,9 +95,7 @@ class Measurement(Operation):
participants: Tuple[int, ...]


OperationType = TypeVar(
"OperationType", bound=Union[CZ, GlobalRz, GlobalW, LocalRz, LocalW, Measurement]
)
OperationType = CZ | GlobalRz | GlobalW | LocalRz | LocalW | Measurement


class ErrorModel(BaseModel, frozen=True, extra="forbid"):
Expand All @@ -106,7 +104,7 @@ class ErrorModel(BaseModel, frozen=True, extra="forbid"):
error_model_type: str = Field(init=False)


class PauliErrorModel(ErrorModel):
class PauliErrorModel(ErrorModel, frozen=True):
"""Pauli error model.

Fields:
Expand All @@ -131,7 +129,7 @@ class ErrorOperation(BaseModel, Generic[ErrorModelType], frozen=True, extra="for
survival_prob: Tuple[float, ...]


class CZError(ErrorOperation[ErrorModelType]):
class CZError(ErrorOperation[ErrorModelType], frozen=True):
"""CZError operation.

Fields:
Expand All @@ -149,7 +147,7 @@ class CZError(ErrorOperation[ErrorModelType]):
single_error: ErrorModelType


class SingleQubitError(ErrorOperation[ErrorModelType]):
class SingleQubitError(ErrorOperation[ErrorModelType], frozen=True):
"""SingleQubitError operation.

Fields:
Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/squin/analysis/nsites/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
if method is not None:
return method(self, frame, stmt)
elif stmt.has_trait(HasSites):
has_sites_trait = stmt.get_trait(HasSites)
has_sites_trait = stmt.get_present_trait(HasSites)
sites = has_sites_trait.get_sites(stmt)
return (NumberSites(sites=sites),)
elif stmt.has_trait(FixedSites):
sites_trait = stmt.get_trait(FixedSites)
sites_trait = stmt.get_present_trait(FixedSites)
return (NumberSites(sites=sites_trait.data),)
else:
return (NoSites(),)
Expand Down
Loading
Loading