Skip to content

Commit 2cac9e6

Browse files
authored
Transversal Rewrites. (#54)
* refactor gemini arch file structure * refactor rewrite for simulation * adding rewrite rule for moves + refactor * refactor LocialInitialize * adding measurement rewrites * removing print
1 parent 4d70c13 commit 2cac9e6

File tree

10 files changed

+174
-70
lines changed

10 files changed

+174
-70
lines changed

demo/ghz_moves_demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,5 @@ def compile_and_visualize(mt: ir.Method, interactive=True):
5656
visualize.debugger(mt, arch_spec, interactive=interactive, atom_marker="s")
5757

5858

59-
# compile_and_visualize(log_depth_ghz)
60-
compile_and_visualize(ghz_optimal, interactive=True)
59+
compile_and_visualize(log_depth_ghz)
60+
# compile_and_visualize(ghz_optimal)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .spec import get_arch_spec as get_arch_spec
2+
from .upstream import SpecializeGemini as SpecializeGemini
Lines changed: 4 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,15 @@
11
from dataclasses import dataclass
22

3-
from kirin import decl, ir, rewrite, types
4-
from kirin.decl import info
3+
from kirin import ir
54
from kirin.dialects import ilist, py
65
from kirin.rewrite import abc as rewrite_abc
76

87
from bloqade.lanes.dialects import move
98
from bloqade.lanes.layout.encoding import (
10-
Direction,
119
MoveType,
1210
)
1311

14-
from .impls import generate_arch
15-
16-
17-
def get_arch_spec():
18-
return generate_arch(hypercube_dims=1, word_size_y=5)
19-
20-
21-
dialect = ir.Dialect("gemini.logical")
22-
23-
24-
@decl.statement(dialect=dialect)
25-
class Fill(ir.Statement):
26-
27-
filled: ir.SSAValue = info.argument(
28-
ilist.IListType[types.Tuple[types.Float, types.Float], types.Any]
29-
)
30-
vacant: ir.ResultValue = info.result(
31-
ilist.IListType[types.Tuple[types.Float, types.Float], types.Any]
32-
)
33-
34-
35-
NumRows = types.TypeVar("NumRows")
36-
37-
38-
@decl.statement(dialect=dialect)
39-
class LogicalInitialize(ir.Statement):
40-
y_locs: ir.SSAValue = info.argument(ilist.IListType[types.Float, NumRows])
41-
x_locs: ir.SSAValue = info.argument(
42-
ilist.IListType[ilist.IListType[types.Float, types.Any], NumRows]
43-
)
44-
45-
46-
@decl.statement(dialect=dialect)
47-
class SiteBusMove(ir.Statement):
48-
y_mask: ir.SSAValue = info.argument(ilist.IListType[types.Bool, types.Literal(5)])
49-
word: int = info.attribute()
50-
bus_id: int = info.attribute()
51-
direction: Direction = info.attribute()
52-
53-
54-
@decl.statement(dialect=dialect)
55-
class WordBusMove(ir.Statement):
56-
y_mask: ir.SSAValue = info.argument(ilist.IListType[types.Bool, types.Literal(5)])
57-
direction: Direction = info.attribute()
12+
from . import stmts
5813

5914

6015
@dataclass
@@ -84,7 +39,7 @@ def rewrite_Statement(self, node: ir.Statement):
8439

8540
if move_type is MoveType.SITE:
8641
node.replace_by(
87-
SiteBusMove(
42+
stmts.SiteBusMove(
8843
y_mask_ref,
8944
word=word,
9045
bus_id=bus_id,
@@ -93,7 +48,7 @@ def rewrite_Statement(self, node: ir.Statement):
9348
)
9449
elif move_type is MoveType.WORD:
9550
node.replace_by(
96-
WordBusMove(
51+
stmts.WordBusMove(
9752
y_mask_ref,
9853
direction=direction,
9954
)
@@ -102,16 +57,3 @@ def rewrite_Statement(self, node: ir.Statement):
10257
raise AssertionError("Unsupported move type for rewrite")
10358

10459
return rewrite_abc.RewriteResult(has_done_something=True)
105-
106-
107-
class SpecializeGemini:
108-
109-
def emit(self, mt: ir.Method, no_raise=True) -> ir.Method:
110-
out = mt.similar(dialects=mt.dialects.add(dialect))
111-
112-
rewrite.Walk(RewriteMoves()).rewrite(out.code)
113-
114-
if not no_raise:
115-
out.verify()
116-
117-
return out

src/bloqade/lanes/arch/gemini/logical/simulation/__init__.py

Whitespace-only changes.
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from dataclasses import dataclass, replace
2+
from itertools import chain
3+
from typing import Callable, Iterator, TypeVar
4+
5+
from kirin import ir
6+
from kirin.dialects import ilist
7+
from kirin.rewrite import abc as rewrite_abc
8+
9+
from bloqade.lanes.dialects import move, place
10+
from bloqade.lanes.layout.encoding import LaneAddress, LocationAddress
11+
12+
AddressType = TypeVar("AddressType", bound=LocationAddress | LaneAddress)
13+
14+
15+
def physical_word_id(address: AddressType) -> Iterator[AddressType]:
16+
if address.word_id == 0:
17+
yield from (replace(address, word_id=word_id) for word_id in range(7))
18+
elif address.word_id == 1:
19+
yield from (replace(address, word_id=word_id) for word_id in range(9, 16, 1))
20+
else:
21+
yield address
22+
23+
24+
@dataclass
25+
class RewriteLocations(rewrite_abc.RewriteRule):
26+
27+
transform_location: Callable[[LocationAddress], Iterator[LocationAddress]] = (
28+
physical_word_id
29+
)
30+
31+
def rewrite_Statement(self, node: ir.Statement):
32+
if not isinstance(
33+
node, (move.Fill, move.LocalR, move.LocalRz, move.Initialize)
34+
):
35+
return rewrite_abc.RewriteResult()
36+
37+
physical_addresses = tuple(
38+
chain.from_iterable(map(self.transform_location, node.location_addresses))
39+
)
40+
41+
attributes: dict[str, ir.Attribute] = {
42+
"location_addresses": ir.PyAttr(physical_addresses)
43+
}
44+
45+
node.replace_by(node.from_stmt(node, attributes=attributes))
46+
return rewrite_abc.RewriteResult(has_done_something=True)
47+
48+
49+
@dataclass
50+
class RewriteMoves(rewrite_abc.RewriteRule):
51+
transform_lanes: Callable[[LaneAddress], Iterator[LaneAddress]] = physical_word_id
52+
53+
def rewrite_Statement(self, node: ir.Statement):
54+
if not isinstance(node, move.Move):
55+
return rewrite_abc.RewriteResult()
56+
57+
physical_lanes = tuple(
58+
chain.from_iterable(map(self.transform_lanes, node.lanes))
59+
)
60+
61+
node.replace_by(move.Move(lanes=physical_lanes))
62+
63+
return rewrite_abc.RewriteResult(has_done_something=True)
64+
65+
66+
@dataclass
67+
class RewriteGetMeasurementResult(rewrite_abc.RewriteRule):
68+
transform_lanes: Callable[[LocationAddress], Iterator[LocationAddress]] = (
69+
physical_word_id
70+
)
71+
72+
def rewrite_Statement(self, node: ir.Statement):
73+
if not isinstance(node, move.GetMeasurementResult):
74+
return rewrite_abc.RewriteResult()
75+
76+
new_results = []
77+
for address in self.transform_lanes(node.location_address):
78+
new_stmt = move.GetMeasurementResult(
79+
node.measurement_future, location_address=address
80+
)
81+
new_results.append(new_stmt.result)
82+
node.insert_before(new_stmt)
83+
84+
node.replace_by(ilist.New(tuple(new_results)))
85+
86+
return rewrite_abc.RewriteResult(has_done_something=True)
87+
88+
89+
class RewriteLogicalToPhysicalConversion(rewrite_abc.RewriteRule):
90+
"""Note that this rewrite is to be combined with RewriteGetMeasurementResult."""
91+
92+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
93+
if not isinstance(node, place.ConvertToPhysicalMeasurements):
94+
return rewrite_abc.RewriteResult()
95+
96+
node.replace_by(ilist.New(tuple(node.args)))
97+
return rewrite_abc.RewriteResult(has_done_something=True)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from ..impls import generate_arch
2+
3+
4+
def get_arch_spec():
5+
return generate_arch(hypercube_dims=1, word_size_y=5)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from bloqade.geometry.dialects import filled
2+
from kirin import decl, ir, types
3+
from kirin.decl import info
4+
from kirin.dialects import ilist
5+
6+
from bloqade.lanes.layout.encoding import Direction
7+
8+
dialect = ir.Dialect("gemini.logical")
9+
10+
11+
@decl.statement(dialect=dialect)
12+
class Fill(ir.Statement):
13+
locations: ir.SSAValue = info.argument(filled.FilledGridType[types.Any, types.Any])
14+
15+
16+
NumGates = types.TypeVar("NumGates")
17+
NumRows = types.TypeVar("NumRows")
18+
NumCols = types.TypeVar("NumCols")
19+
20+
21+
@decl.statement(dialect=dialect)
22+
class LogicalInitialize(ir.Statement):
23+
location_groups: ir.SSAValue = info.argument(
24+
ilist.IListType[filled.FilledGridType[NumRows, NumCols], NumGates]
25+
)
26+
thetas: ir.SSAValue = info.argument(ilist.IListType[types.Float, NumGates])
27+
phis: ir.SSAValue = info.argument(ilist.IListType[types.Float, NumGates])
28+
lams: ir.SSAValue = info.argument(ilist.IListType[types.Float, NumGates])
29+
30+
31+
@decl.statement(dialect=dialect)
32+
class SiteBusMove(ir.Statement):
33+
y_mask: ir.SSAValue = info.argument(ilist.IListType[types.Bool, types.Literal(5)])
34+
word: int = info.attribute()
35+
bus_id: int = info.attribute()
36+
direction: Direction = info.attribute()
37+
38+
39+
@decl.statement(dialect=dialect)
40+
class WordBusMove(ir.Statement):
41+
y_mask: ir.SSAValue = info.argument(ilist.IListType[types.Bool, types.Literal(5)])
42+
direction: Direction = info.attribute()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from kirin import ir, rewrite
2+
3+
from .rewrite import RewriteMoves
4+
from .stmts import dialect
5+
6+
7+
class SpecializeGemini:
8+
9+
def emit(self, mt: ir.Method, no_raise=True) -> ir.Method:
10+
out = mt.similar(dialects=mt.dialects.add(dialect))
11+
12+
rewrite.Walk(RewriteMoves()).rewrite(out.code)
13+
14+
if not no_raise:
15+
out.verify()
16+
17+
return out

src/bloqade/lanes/heuristics/fixed.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,6 @@ def compute_moves(
272272
moves.extend(self.site_moves(groups.get((0, 0), []), 0))
273273
moves.extend(self.site_moves(groups.get((1, 1), []), 1))
274274

275-
for move_group in moves:
276-
print(move_group)
277-
278275
return moves
279276

280277

test/test_gemini_logical.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from bloqade.lanes.arch.gemini import logical
77
from bloqade.lanes.arch.gemini.impls import generate_arch
8+
from bloqade.lanes.arch.gemini.logical.rewrite import RewriteMoves
9+
from bloqade.lanes.arch.gemini.logical.stmts import SiteBusMove
810
from bloqade.lanes.dialects import move
911
from bloqade.lanes.layout.encoding import (
1012
Direction,
@@ -41,7 +43,7 @@ def test_logical_architecture_rewrite_site():
4143
)
4244
)
4345

44-
rewrite_rule = rewrite.Walk(logical.RewriteMoves())
46+
rewrite_rule = rewrite.Walk(RewriteMoves())
4547

4648
rewrite_rule.rewrite(test_block)
4749

@@ -50,7 +52,7 @@ def test_logical_architecture_rewrite_site():
5052
const_list := py.Constant(ilist.IList([True, True, True, True, False]))
5153
)
5254
expected_block.stmts.append(
53-
logical.SiteBusMove(
55+
SiteBusMove(
5456
y_mask=const_list.result,
5557
word=0,
5658
bus_id=0,
@@ -66,7 +68,7 @@ def test_logical_architecture_rewrite_site_no_lanes():
6668

6769
test_block.stmts.append(move.Move(lanes=()))
6870

69-
rewrite_rule = rewrite.Walk(logical.RewriteMoves())
71+
rewrite_rule = rewrite.Walk(RewriteMoves())
7072
result = rewrite_rule.rewrite(test_block)
7173
assert not result.has_done_something
7274

0 commit comments

Comments
 (0)