Skip to content

Commit 4d91bac

Browse files
authored
Debugging Transversal Moves. (#57)
* fixing bugs + updating demo * adding constants to visualization
1 parent 2cac9e6 commit 4d91bac

File tree

4 files changed

+41
-10
lines changed

4 files changed

+41
-10
lines changed

demo/ghz_moves_demo.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from bloqade.native.upstream import SquinToNative
2-
from kirin import ir
2+
from kirin import ir, rewrite
33
from kirin.dialects import ilist
44

55
from bloqade import qubit, squin
66
from bloqade.lanes import visualize
77
from bloqade.lanes.arch.gemini import logical
8+
from bloqade.lanes.arch.gemini.impls import generate_arch
9+
from bloqade.lanes.arch.gemini.logical.simulation import rewrite as sim_rewrite
810
from bloqade.lanes.heuristics import fixed
911
from bloqade.lanes.upstream import NativeToPlace, PlaceToMove
1012

@@ -40,7 +42,7 @@ def ghz_optimal():
4042
squin.broadcast.cx(qs[:5], qs[5:])
4143

4244

43-
def compile_and_visualize(mt: ir.Method, interactive=True):
45+
def compile(mt: ir.Method, transversal: bool = False):
4446
# Compile to move dialect
4547

4648
mt = SquinToNative().emit(mt)
@@ -51,10 +53,26 @@ def compile_and_visualize(mt: ir.Method, interactive=True):
5153
fixed.LogicalMoveScheduler(),
5254
).emit(mt)
5355

54-
arch_spec = logical.get_arch_spec()
56+
if transversal:
57+
rewrite.Walk(
58+
rewrite.Chain(sim_rewrite.RewriteLocations(), sim_rewrite.RewriteMoves())
59+
).rewrite(mt.code)
60+
return mt
5561

56-
visualize.debugger(mt, arch_spec, interactive=interactive, atom_marker="s")
62+
63+
def compile_and_visualize(mt: ir.Method, interactive=True, transversal: bool = False):
64+
# Compile to move dialect
65+
mt = compile(mt, transversal=transversal)
66+
if transversal:
67+
arch_spec = generate_arch(4)
68+
marker = "o"
69+
else:
70+
arch_spec = logical.get_arch_spec()
71+
marker = "s"
72+
73+
visualize.debugger(mt, arch_spec, interactive=interactive, atom_marker=marker)
5774

5875

59-
compile_and_visualize(log_depth_ghz)
60-
# compile_and_visualize(ghz_optimal)
76+
compile_and_visualize(log_depth_ghz, transversal=False)
77+
compile_and_visualize(ghz_optimal, transversal=False)
78+
compile_and_visualize(ghz_optimal, transversal=True)

src/bloqade/lanes/arch/gemini/impls.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def site_buses(site_addresses: np.ndarray):
3232
def hypercube_busses(hypercube_dims: int):
3333
word_buses: list[Bus] = []
3434
for shift in range(hypercube_dims):
35-
m = 1 << shift
35+
m = 1 << (hypercube_dims - shift - 1)
3636

3737
srcs = []
3838
dsts = []
@@ -45,6 +45,7 @@ def hypercube_busses(hypercube_dims: int):
4545
dsts.append(dst)
4646

4747
word_buses.append(Bus(tuple(srcs), tuple(dsts)))
48+
4849
return tuple(word_buses)
4950

5051

src/bloqade/lanes/arch/gemini/logical/simulation/rewrite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def physical_word_id(address: AddressType) -> Iterator[AddressType]:
1616
if address.word_id == 0:
1717
yield from (replace(address, word_id=word_id) for word_id in range(7))
1818
elif address.word_id == 1:
19-
yield from (replace(address, word_id=word_id) for word_id in range(9, 16, 1))
19+
yield from (replace(address, word_id=word_id) for word_id in range(8, 15, 1))
2020
else:
2121
yield address
2222

src/bloqade/lanes/visualize.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from itertools import chain
22

33
from kirin import ir
4+
from kirin.dialects import py
45
from matplotlib import figure, pyplot as plt
56
from matplotlib.axes import Axes
67
from matplotlib.widgets import Button
@@ -126,11 +127,22 @@ def get_drawer(mt: ir.Method, arch_spec: ArchSpec, ax: Axes, atom_marker: str =
126127
y_max += 0.1 * y_width
127128

128129
steps: list[tuple[ir.Statement, AtomState]] = []
129-
130+
constants = {}
130131
for stmt in mt.callable_region.walk():
131132
curr_state = frame.atom_state_map.get(stmt)
132133
if isinstance(curr_state, AtomState):
133134
steps.append((stmt, curr_state))
135+
elif isinstance(stmt, py.Constant):
136+
constants[stmt.result] = stmt.value.unwrap()
137+
138+
def stmt_text(stmt: ir.Statement) -> str:
139+
if len(stmt.args) == 0:
140+
return f"{type(stmt).__name__}"
141+
return (
142+
f"{type(stmt).__name__}("
143+
+ ", ".join(f"{constants.get(arg,'missing')}" for arg in stmt.args)
144+
+ ")"
145+
)
134146

135147
def draw(step_index: int):
136148
if len(steps) == 0:
@@ -145,7 +157,7 @@ def draw(step_index: int):
145157
)
146158
curr_state.draw_moves(arch_spec, ax=ax, color="orange")
147159

148-
ax.set_title(f"Step {step_index+1} / {len(steps)}: {type(stmt).__name__}")
160+
ax.set_title(f"Step {step_index+1} / {len(steps)}: {stmt_text(stmt)}")
149161

150162
ax.set_xlim(x_min, x_max)
151163
ax.set_ylim(y_min, y_max)

0 commit comments

Comments
 (0)