Skip to content

Commit 00776a0

Browse files
authored
Porting rewrites from previous work. (#20)
* porting rewrites * adding rewrite rules * adding canonicalize * adding canonicalize * Adding tests and fixing bugs
1 parent cf74eef commit 00776a0

File tree

10 files changed

+465
-30
lines changed

10 files changed

+465
-30
lines changed

src/bloqade/lanes/dialects/circuit.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
from kirin import ir, types
22
from kirin.decl import info, statement
33

4-
from .execute import ExitLowLevel, LowLevelStmt
4+
from .execute import ExitLowLevel, QuantumStmt
55

66
dialect = ir.Dialect(name="lowlevel.circuit")
77

88

99
@statement(dialect=dialect)
10-
class CZ(LowLevelStmt):
11-
pairs: tuple[tuple[int, int], ...] = info.attribute()
10+
class CZ(QuantumStmt):
11+
targets: tuple[int, ...] = info.attribute()
12+
controls: tuple[int, ...] = info.attribute()
1213

1314

1415
@statement(dialect=dialect)
15-
class R(LowLevelStmt):
16-
inputs: tuple[int, ...] = info.attribute()
16+
class R(QuantumStmt):
17+
qubits: tuple[int, ...] = info.attribute()
1718

1819
axis_angle: ir.SSAValue = info.argument(type=types.Float)
1920
rotation_angle: ir.SSAValue = info.argument(type=types.Float)
2021

2122

2223
@statement(dialect=dialect)
23-
class Rz(LowLevelStmt):
24-
inputs: tuple[int, ...] = info.attribute()
24+
class Rz(QuantumStmt):
25+
qubits: tuple[int, ...] = info.attribute()
2526
rotation_angle: ir.SSAValue = info.argument(type=types.Float)
2627

2728

src/bloqade/lanes/dialects/cpu.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,12 @@
11
from kirin import ir, types
22
from kirin.decl import info, statement
33

4-
from .execute import LowLevelStmt
5-
64
dialect = ir.Dialect(name="lowlevel.cpu")
75

86

97
@statement(dialect=dialect)
10-
class StaticFloat(LowLevelStmt):
8+
class StaticFloat(ir.Statement):
119
traits = frozenset({ir.ConstantLike()})
1210

1311
value: float = info.attribute(type=types.Float)
1412
result: ir.ResultValue = info.result(type=types.Float)
15-
16-
17-
@statement(dialect=dialect)
18-
class StaticInt(LowLevelStmt):
19-
traits = frozenset({ir.ConstantLike()})
20-
21-
value: int = info.attribute(type=types.Int)
22-
23-
result: ir.ResultValue = info.result(type=types.Int)

src/bloqade/lanes/dialects/execute.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010

1111
@statement
12-
class LowLevelStmt(ir.Statement):
12+
class QuantumStmt(ir.Statement):
1313
"""This is a base class for all low level statements."""
1414

1515
state_before: ir.SSAValue = info.argument(StateType)
16-
result: ir.ResultValue = info.result(StateType)
16+
state_after: ir.ResultValue = info.result(StateType)
1717

1818

1919
@statement
@@ -39,7 +39,6 @@ class ExecuteLowLevel(ir.Statement):
3939
qubits: tuple[ir.SSAValue, ...] = info.argument(bloqade_types.QubitType)
4040
body: ir.Region = info.region(multi=False)
4141
starting_addresses: tuple[LocationAddress, ...] | None = info.attribute()
42-
measure_result: tuple[ir.ResultValue, ...] = info.result()
4342

4443
def __init__(
4544
self,
@@ -81,7 +80,7 @@ def check(self) -> None:
8180

8281
stmt = body_block.first_stmt
8382
while stmt is not last_stmt:
84-
if not isinstance(stmt, LowLevelStmt):
83+
if not isinstance(stmt, QuantumStmt):
8584
raise exception.StaticCheckError(
8685
"All statements in ShuttleAtoms body must be ByteCodeStmt"
8786
)

src/bloqade/lanes/dialects/move.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,42 @@
22
from kirin.decl import info, statement
33

44
from ..layout.encoding import LocationAddress, MoveType
5-
from .execute import ExitLowLevel, LowLevelStmt
5+
from .execute import ExitLowLevel, QuantumStmt
66

77
dialect = ir.Dialect(name="lowlevel.move")
88

99

1010
@statement(dialect=dialect)
11-
class CZ(LowLevelStmt):
11+
class CZ(QuantumStmt):
1212
pass
1313

1414

1515
@statement(dialect=dialect)
16-
class LocalR(LowLevelStmt):
16+
class LocalR(QuantumStmt):
1717
physical_addr: tuple[LocationAddress, ...] = info.attribute()
1818

1919
axis_angle: ir.SSAValue = info.argument(type=types.Float)
2020
rotation_angle: ir.SSAValue = info.argument(type=types.Float)
2121

2222

2323
@statement(dialect=dialect)
24-
class GlobalR(LowLevelStmt):
24+
class GlobalR(QuantumStmt):
2525
axis_angle: ir.SSAValue = info.argument(type=types.Float)
2626
rotation_angle: ir.SSAValue = info.argument(type=types.Float)
2727

2828

2929
@statement(dialect=dialect)
30-
class LocalRz(LowLevelStmt):
30+
class LocalRz(QuantumStmt):
3131
physical_addr: tuple[LocationAddress, ...] = info.attribute()
3232
rotation_angle: ir.SSAValue = info.argument(type=types.Float)
3333

3434

3535
@statement(dialect=dialect)
36-
class GlobalRz(LowLevelStmt):
36+
class GlobalRz(QuantumStmt):
3737
rotation_angle: ir.SSAValue = info.argument(type=types.Float)
3838

3939

40-
class Move(LowLevelStmt):
40+
class Move(QuantumStmt):
4141
lanes: tuple[MoveType, ...] = info.attribute()
4242

4343

src/bloqade/lanes/passes/__init__.py

Whitespace-only changes.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from dataclasses import dataclass
2+
3+
from bloqade.native.dialects.gate import stmts as gates
4+
from kirin import ir, passes, rewrite
5+
from kirin.passes.hint_const import HintConst
6+
from kirin.rewrite import abc
7+
8+
9+
class HoistClassicalStatements(abc.RewriteRule):
10+
"""This rewrite rule shift any classical statements that are pure
11+
(quantum statements are never pure) to the beginning of the block.
12+
swapping the other with quantum statements. This is useful after
13+
rewriting the native operations to placement operations,
14+
so that we can merge the placement regions together.
15+
16+
Note that this rule also works with subroutines that contain
17+
quantum statements because these are also not pure
18+
19+
"""
20+
21+
def is_pure(self, node: ir.Statement) -> bool:
22+
return (
23+
node.has_trait(ir.Pure)
24+
or (maybe_pure := node.get_trait(ir.MaybePure)) is not None
25+
and maybe_pure.is_pure(node)
26+
)
27+
28+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
29+
if not (
30+
isinstance(node, (gates.CZ, gates.R, gates.Rz))
31+
and (next_node := node.next_stmt) is not None
32+
and not next_node.has_trait(ir.IsTerminator)
33+
and self.is_pure(next_node)
34+
):
35+
return abc.RewriteResult()
36+
37+
next_node.detach()
38+
next_node.insert_before(node)
39+
return abc.RewriteResult(has_done_something=True)
40+
41+
42+
@dataclass
43+
class CanonicalizeNative(passes.Pass):
44+
45+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
46+
result = HintConst(mt.dialects)(mt)
47+
result = (
48+
rewrite.Fixpoint(rewrite.Walk(HoistClassicalStatements()))
49+
.rewrite(mt.code)
50+
.join(result)
51+
)
52+
53+
return result

src/bloqade/lanes/rewrite/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)