Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/bloqade/lanes/dialects/circuit.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
from kirin import ir, types
from kirin.decl import info, statement

from .execute import ExitLowLevel, LowLevelStmt
from .execute import ExitLowLevel, QuantumStmt

dialect = ir.Dialect(name="lowlevel.circuit")


@statement(dialect=dialect)
class CZ(LowLevelStmt):
pairs: tuple[tuple[int, int], ...] = info.attribute()
class CZ(QuantumStmt):
targets: tuple[int, ...] = info.attribute()
controls: tuple[int, ...] = info.attribute()


@statement(dialect=dialect)
class R(LowLevelStmt):
inputs: tuple[int, ...] = info.attribute()
class R(QuantumStmt):
qubits: tuple[int, ...] = info.attribute()

axis_angle: ir.SSAValue = info.argument(type=types.Float)
rotation_angle: ir.SSAValue = info.argument(type=types.Float)


@statement(dialect=dialect)
class Rz(LowLevelStmt):
inputs: tuple[int, ...] = info.attribute()
class Rz(QuantumStmt):
qubits: tuple[int, ...] = info.attribute()
rotation_angle: ir.SSAValue = info.argument(type=types.Float)


Expand Down
13 changes: 1 addition & 12 deletions src/bloqade/lanes/dialects/cpu.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,12 @@
from kirin import ir, types
from kirin.decl import info, statement

from .execute import LowLevelStmt

dialect = ir.Dialect(name="lowlevel.cpu")


@statement(dialect=dialect)
class StaticFloat(LowLevelStmt):
class StaticFloat(ir.Statement):
traits = frozenset({ir.ConstantLike()})

value: float = info.attribute(type=types.Float)
result: ir.ResultValue = info.result(type=types.Float)


@statement(dialect=dialect)
class StaticInt(LowLevelStmt):
traits = frozenset({ir.ConstantLike()})

value: int = info.attribute(type=types.Int)

result: ir.ResultValue = info.result(type=types.Int)
7 changes: 3 additions & 4 deletions src/bloqade/lanes/dialects/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@


@statement
class LowLevelStmt(ir.Statement):
class QuantumStmt(ir.Statement):
"""This is a base class for all low level statements."""

state_before: ir.SSAValue = info.argument(StateType)
result: ir.ResultValue = info.result(StateType)
state_after: ir.ResultValue = info.result(StateType)


@statement
Expand All @@ -39,7 +39,6 @@ class ExecuteLowLevel(ir.Statement):
qubits: tuple[ir.SSAValue, ...] = info.argument(bloqade_types.QubitType)
body: ir.Region = info.region(multi=False)
starting_addresses: tuple[LocationAddress, ...] | None = info.attribute()
measure_result: tuple[ir.ResultValue, ...] = info.result()

def __init__(
self,
Expand Down Expand Up @@ -81,7 +80,7 @@ def check(self) -> None:

stmt = body_block.first_stmt
while stmt is not last_stmt:
if not isinstance(stmt, LowLevelStmt):
if not isinstance(stmt, QuantumStmt):
raise exception.StaticCheckError(
"All statements in ShuttleAtoms body must be ByteCodeStmt"
)
Expand Down
14 changes: 7 additions & 7 deletions src/bloqade/lanes/dialects/move.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,42 @@
from kirin.decl import info, statement

from ..layout.encoding import LocationAddress, MoveType
from .execute import ExitLowLevel, LowLevelStmt
from .execute import ExitLowLevel, QuantumStmt

dialect = ir.Dialect(name="lowlevel.move")


@statement(dialect=dialect)
class CZ(LowLevelStmt):
class CZ(QuantumStmt):
pass


@statement(dialect=dialect)
class LocalR(LowLevelStmt):
class LocalR(QuantumStmt):
physical_addr: tuple[LocationAddress, ...] = info.attribute()

axis_angle: ir.SSAValue = info.argument(type=types.Float)
rotation_angle: ir.SSAValue = info.argument(type=types.Float)


@statement(dialect=dialect)
class GlobalR(LowLevelStmt):
class GlobalR(QuantumStmt):
axis_angle: ir.SSAValue = info.argument(type=types.Float)
rotation_angle: ir.SSAValue = info.argument(type=types.Float)


@statement(dialect=dialect)
class LocalRz(LowLevelStmt):
class LocalRz(QuantumStmt):
physical_addr: tuple[LocationAddress, ...] = info.attribute()
rotation_angle: ir.SSAValue = info.argument(type=types.Float)


@statement(dialect=dialect)
class GlobalRz(LowLevelStmt):
class GlobalRz(QuantumStmt):
rotation_angle: ir.SSAValue = info.argument(type=types.Float)


class Move(LowLevelStmt):
class Move(QuantumStmt):
lanes: tuple[MoveType, ...] = info.attribute()


Expand Down
Empty file.
53 changes: 53 additions & 0 deletions src/bloqade/lanes/passes/canonicalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from dataclasses import dataclass

from bloqade.native.dialects.gate import stmts as gates
from kirin import ir, passes, rewrite
from kirin.passes.hint_const import HintConst
from kirin.rewrite import abc


class HoistClassicalStatements(abc.RewriteRule):
"""This rewrite rule shift any classical statements that are pure
(quantum statements are never pure) to the beginning of the block.
swapping the other with quantum statements. This is useful after
rewriting the native operations to placement operations,
so that we can merge the placement regions together.

Note that this rule also works with subroutines that contain
quantum statements because these are also not pure

"""

def is_pure(self, node: ir.Statement) -> bool:
return (
node.has_trait(ir.Pure)
or (maybe_pure := node.get_trait(ir.MaybePure)) is not None
and maybe_pure.is_pure(node)
)

def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
if not (
isinstance(node, (gates.CZ, gates.R, gates.Rz))
and (next_node := node.next_stmt) is not None
and not next_node.has_trait(ir.IsTerminator)
and self.is_pure(next_node)
):
return abc.RewriteResult()

next_node.detach()
next_node.insert_before(node)
return abc.RewriteResult(has_done_something=True)


@dataclass
class CanonicalizeNative(passes.Pass):

def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
result = HintConst(mt.dialects)(mt)
result = (
rewrite.Fixpoint(rewrite.Walk(HoistClassicalStatements()))
.rewrite(mt.code)
.join(result)
)

return result
Empty file.
Loading
Loading