Skip to content

Commit 0ba5822

Browse files
authored
Adding initial angles for initialize to layout analysis (#47)
* Adding initial angles for initialize to layout analysis * remove exception * removing args
1 parent c73a0e9 commit 0ba5822

File tree

6 files changed

+129
-22
lines changed

6 files changed

+129
-22
lines changed

src/bloqade/lanes/analysis/layout/analysis.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from kirin import ir
66
from kirin.analysis.forward import Forward, ForwardFrame
77
from kirin.lattice import EmptyLattice
8-
from typing_extensions import Self
98

109
from bloqade.lanes.layout.encoding import LocationAddress
1110

@@ -39,12 +38,15 @@ def compute_layout(
3938

4039
@dataclass
4140
class LayoutAnalysis(Forward):
42-
keys = ("circuit.layout",)
41+
keys = ("place.layout",)
4342
lattice = EmptyLattice
4443

4544
heuristic: LayoutHeuristicABC
4645
address_entries: dict[ir.SSAValue, address.Address]
4746
all_qubits: tuple[int, ...] = field(init=False)
47+
thetas: dict[int, ir.SSAValue] = field(default_factory=dict, init=False)
48+
phis: dict[int, ir.SSAValue] = field(default_factory=dict, init=False)
49+
lams: dict[int, ir.SSAValue] = field(default_factory=dict, init=False)
4850
stages: list[tuple[tuple[int, int], ...]] = field(default_factory=list, init=False)
4951
global_address_stack: list[int] = field(default_factory=list, init=False)
5052

@@ -60,9 +62,12 @@ def __post_init__(self) -> None:
6062
)
6163
super().__post_init__()
6264

63-
def initialize(self) -> Self:
65+
def initialize(self):
6466
self.stages.clear()
6567
self.global_address_stack.clear()
68+
self.thetas.clear()
69+
self.phis.clear()
70+
self.lams.clear()
6671
return super().initialize()
6772

6873
def eval_stmt_fallback(self, frame, stmt):
@@ -76,15 +81,36 @@ def add_stage(self, control: tuple[int, ...], target: tuple[int, ...]):
7681
def method_self(self, method: ir.Method):
7782
return EmptyLattice.bottom()
7883

79-
def get_layout_no_raise(self, method: ir.Method) -> tuple[LocationAddress, ...]:
84+
def process_results(self):
85+
layout = self.heuristic.compute_layout(self.all_qubits, self.stages)
86+
init_locations = tuple(
87+
loc
88+
for qubit_id, loc in zip(self.all_qubits, layout)
89+
if qubit_id in self.thetas
90+
)
91+
thetas = tuple(
92+
self.thetas[qubit_id]
93+
for qubit_id in self.all_qubits
94+
if qubit_id in self.thetas
95+
)
96+
phis = tuple(
97+
self.phis[qubit_id] for qubit_id in self.all_qubits if qubit_id in self.phis
98+
)
99+
lams = tuple(
100+
self.lams[qubit_id] for qubit_id in self.all_qubits if qubit_id in self.lams
101+
)
102+
103+
return layout, init_locations, thetas, phis, lams
104+
105+
def get_layout_no_raise(self, method: ir.Method):
80106
"""Get the layout for a given method."""
81107
self.run_no_raise(method)
82-
return self.heuristic.compute_layout(self.all_qubits, self.stages)
108+
return self.process_results()
83109

84-
def get_layout(self, method: ir.Method) -> tuple[LocationAddress, ...]:
110+
def get_layout(self, method: ir.Method):
85111
"""Get the layout for a given method."""
86112
self.run(method)
87-
return self.heuristic.compute_layout(self.all_qubits, self.stages)
113+
return self.process_results()
88114

89115
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
90116
return tuple(EmptyLattice.bottom() for _ in node.results)

src/bloqade/lanes/arch/gemini/logical.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,28 @@ def get_arch_spec():
2121
dialect = ir.Dialect("gemini.logical")
2222

2323

24+
@decl.statement(dialect=dialect)
25+
class Fill(ir.Statement):
26+
27+
filled: ir.SSAValue = info.argument(
28+
ilist.IListType[types.Tuple[types.Float, types.Float], types.Any]
29+
)
30+
vacant: ir.ResultValue = info.result(
31+
ilist.IListType[types.Tuple[types.Float, types.Float], types.Any]
32+
)
33+
34+
35+
NumRows = types.TypeVar("NumRows")
36+
37+
38+
@decl.statement(dialect=dialect)
39+
class LogicalInitialize(ir.Statement):
40+
y_locs: ir.SSAValue = info.argument(ilist.IListType[types.Float, NumRows])
41+
x_locs: ir.SSAValue = info.argument(
42+
ilist.IListType[ilist.IListType[types.Float, types.Any], NumRows]
43+
)
44+
45+
2446
@decl.statement(dialect=dialect)
2547
class SiteBusMove(ir.Statement):
2648
y_mask: ir.SSAValue = info.argument(ilist.IListType[types.Bool, types.Literal(5)])

src/bloqade/lanes/dialects/move.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,21 @@
1010
dialect = ir.Dialect(name="lanes.move")
1111

1212

13+
@statement(dialect=dialect)
14+
class Fill(ir.Statement):
15+
traits = frozenset({lowering.FromPythonCall()})
16+
17+
location_addresses: tuple[LocationAddress, ...] = info.attribute()
18+
19+
1320
@statement(dialect=dialect)
1421
class Initialize(ir.Statement):
1522
traits = frozenset({lowering.FromPythonCall()})
23+
1624
location_addresses: tuple[LocationAddress, ...] = info.attribute()
25+
thetas: tuple[ir.SSAValue, ...] = info.argument(type=types.Float)
26+
phis: tuple[ir.SSAValue, ...] = info.argument(type=types.Float)
27+
lams: tuple[ir.SSAValue, ...] = info.argument(type=types.Float)
1728

1829

1930
@statement(dialect=dialect)

src/bloqade/lanes/dialects/place.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ def __init__(
3030
)
3131

3232

33+
@statement(dialect=dialect)
34+
class Initialize(QuantumStmt):
35+
qubits: tuple[int, ...] = info.attribute()
36+
37+
theta: ir.SSAValue = info.argument(type=types.Float)
38+
phi: ir.SSAValue = info.argument(type=types.Float)
39+
lam: ir.SSAValue = info.argument(type=types.Float)
40+
41+
3342
@statement(dialect=dialect)
3443
class CZ(QuantumStmt):
3544
targets: tuple[int, ...] = info.attribute()
@@ -215,7 +224,7 @@ def end_measure(
215224
return (new_state,) + (EmptyLattice.bottom(),) * len(stmt.qubits)
216225

217226

218-
@dialect.register(key="circuit.layout")
227+
@dialect.register(key="place.layout")
219228
class InitialLayoutMethods(interp.MethodTable):
220229

221230
@interp.impl(CZ)
@@ -229,6 +238,21 @@ def cz(
229238

230239
return ()
231240

241+
@interp.impl(Initialize)
242+
def initialize(
243+
self,
244+
_interp: LayoutAnalysis,
245+
frame: ForwardFrame[EmptyLattice],
246+
stmt: Initialize,
247+
):
248+
for qubit_id in stmt.qubits:
249+
global_qubit_id = _interp.global_address_stack[qubit_id]
250+
_interp.thetas[global_qubit_id] = stmt.theta
251+
_interp.phis[global_qubit_id] = stmt.phi
252+
_interp.lams[global_qubit_id] = stmt.lam
253+
254+
return ()
255+
232256
@interp.impl(StaticPlacement)
233257
def static_circuit(
234258
self,

src/bloqade/lanes/rewrite/place2move.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,30 @@ def rewrite_Statement(self, node: ir.Statement):
292292

293293
@dataclass
294294
class InsertInitialize(RewriteRule):
295+
init_locations: tuple[LocationAddress, ...]
296+
thetas: tuple[ir.SSAValue, ...]
297+
phis: tuple[ir.SSAValue, ...]
298+
lams: tuple[ir.SSAValue, ...]
299+
300+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
301+
if not (len(self.init_locations) > 0 and isinstance(node, func.Function)):
302+
return RewriteResult()
303+
304+
first_stmt = node.body.blocks[0].first_stmt
305+
306+
if first_stmt is None or isinstance(first_stmt, move.Initialize):
307+
return RewriteResult()
308+
move.Initialize(
309+
location_addresses=self.init_locations,
310+
thetas=self.thetas,
311+
phis=self.phis,
312+
lams=self.lams,
313+
).insert_before(first_stmt)
314+
return RewriteResult(has_done_something=True)
315+
316+
317+
@dataclass
318+
class InsertFill(RewriteRule):
295319
initial_layout: tuple[LocationAddress, ...]
296320

297321
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
@@ -300,11 +324,11 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
300324

301325
first_stmt = node.body.blocks[0].first_stmt
302326

303-
if first_stmt is None or isinstance(first_stmt, move.Initialize):
327+
if first_stmt is None or isinstance(first_stmt, move.Fill):
304328
return RewriteResult()
305-
move.Initialize(location_addresses=self.initial_layout).insert_before(
306-
first_stmt
307-
)
329+
330+
move.Fill(location_addresses=self.initial_layout).insert_before(first_stmt)
331+
308332
return RewriteResult(has_done_something=True)
309333

310334

src/bloqade/lanes/upstream.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ def emit(self, mt: Method, no_raise: bool = True):
5252

5353
if no_raise:
5454
address_frame, _ = address.AddressAnalysis(out.dialects).run_no_raise(out)
55-
initial_layout = layout.LayoutAnalysis(
55+
initial_layout, init_locations, thetas, phis, lams = layout.LayoutAnalysis(
5656
out.dialects, self.layout_heristic, address_frame.entries
5757
).get_layout_no_raise(out)
58+
5859
placement_frame, _ = placement.PlacementAnalysis(
5960
out.dialects,
6061
initial_layout,
@@ -63,7 +64,7 @@ def emit(self, mt: Method, no_raise: bool = True):
6364
).run_no_raise(out)
6465
else:
6566
address_frame, _ = address.AddressAnalysis(out.dialects).run(out)
66-
initial_layout = layout.LayoutAnalysis(
67+
initial_layout, init_locations, thetas, phis, lams = layout.LayoutAnalysis(
6768
out.dialects, self.layout_heristic, address_frame.entries
6869
).get_layout(out)
6970
placement_frame, _ = placement.PlacementAnalysis(
@@ -73,15 +74,14 @@ def emit(self, mt: Method, no_raise: bool = True):
7374
self.placement_strategy,
7475
).run(out)
7576

76-
placement_analysis = placement_frame.entries
77-
args = (self.move_scheduler, placement_analysis)
7877
rule = rewrite.Chain(
79-
place2move.InsertInitialize(initial_layout),
80-
place2move.InsertMoves(*args),
81-
place2move.RewriteCZ(*args),
82-
place2move.RewriteR(*args),
83-
place2move.RewriteRz(*args),
84-
place2move.InsertMeasure(*args),
78+
place2move.InsertFill(initial_layout),
79+
place2move.InsertInitialize(init_locations, thetas, phis, lams),
80+
place2move.InsertMoves(self.move_scheduler, placement_frame.entries),
81+
place2move.RewriteCZ(self.move_scheduler, placement_frame.entries),
82+
place2move.RewriteR(self.move_scheduler, placement_frame.entries),
83+
place2move.RewriteRz(self.move_scheduler, placement_frame.entries),
84+
place2move.InsertMeasure(self.move_scheduler, placement_frame.entries),
8585
)
8686
rewrite.Walk(rule).rewrite(out.code)
8787

0 commit comments

Comments
 (0)