Skip to content

Commit fcbe54f

Browse files
committed
Updating arrow size and fixing issue with color of arrows during path (#75)
* updating arrow width and fixing issue with color of arrows during path * put continue while rendering path step by step * simplfying interface * make implementation a bit more clear and robust
1 parent e4a0b85 commit fcbe54f

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

src/bloqade/shuttle/visualizer/impl/path.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ class PathVisualizerMethods(interp.MethodTable):
2424

2525
@interp.impl(Play)
2626
def play(self, _interp: PathVisualizer[Renderer], frame: interp.Frame, stmt: Play):
27-
_interp.renderer.clear_paths()
2827

2928
path = frame.get(stmt.path)
3029
# path is generated by "main" interpreter so
@@ -40,7 +39,6 @@ def play(self, _interp: PathVisualizer[Renderer], frame: interp.Frame, stmt: Pla
4039
f"Expected a Path or tuple of Paths, got {path}"
4140
)
4241

43-
_interp.renderer.show()
4442
return ()
4543

4644
@interp.impl(Parallel)

src/bloqade/shuttle/visualizer/renderers/matplotlib.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class MatplotlibRenderer(RendererInterface):
3737
ax: Axes = field(default_factory=default_ax, repr=False)
3838

3939
gate_display_options: GateDisplayOptions = field(default_factory=GateDisplayOptions)
40+
arrow_rescale: float = field(default=1.0, kw_only=True)
4041

4142
active_x_tones: set[int] = field(default_factory=set, repr=False, init=False)
4243
active_y_tones: set[int] = field(default_factory=set, repr=False, init=False)
@@ -59,6 +60,14 @@ def __post_init__(self) -> None:
5960
self.exit_button = Button(exit_ax, "Exit")
6061
self.exit_button.on_clicked(lambda event: exit())
6162

63+
@property
64+
def fov_size(self) -> float:
65+
return np.sqrt((self.xmax - self.xmin) ** 2 + (self.ymax - self.ymin) ** 2)
66+
67+
@property
68+
def arrow_scale(self) -> float:
69+
return self.fov_size / 400.0 * self.arrow_rescale
70+
6271
def update_x_bounds(self, y: float) -> None:
6372
xmin = min(curr_xmin := getattr(self, "xmin", float("inf")), y - 3)
6473
xmax = max(curr_xmax := getattr(self, "xmax", float("-inf")), y + 3)
@@ -185,29 +194,34 @@ def render_path(self, pth: path.Path) -> None:
185194
for way_point in path_action.way_points
186195
]
187196

188-
if len(all_waypoints) == 0:
197+
num_unique_waypoints = len(set(all_waypoints))
198+
if num_unique_waypoints < 2:
189199
return
190200

201+
num_arrows = num_unique_waypoints - 1
191202
first_waypoint = all_waypoints[0]
192203
curr_x = first_waypoint.x_positions
193204
curr_y = first_waypoint.y_positions
194205

195206
color_map = plt.get_cmap("viridis")
196207

197-
num_steps = len(all_waypoints) - 1
198208
step = 0
199209

200210
x_tones = np.array(pth.x_tones)
201211
y_tones = np.array(pth.y_tones)
202212

203213
x = all_waypoints[0].x_positions
204214
y = all_waypoints[0].y_positions
215+
self.clear_paths()
216+
self.show()
205217

206218
for action in pth.path:
207219
if isinstance(action, taskgen.WayPointsAction):
208-
for way_point in action.way_points:
209-
x = way_point.x_positions
210-
y = way_point.y_positions
220+
for start, end in zip(action.way_points[:-1], action.way_points[1:]):
221+
x = end.x_positions
222+
y = end.y_positions
223+
curr_x = start.x_positions
224+
curr_y = start.y_positions
211225

212226
for (x_tone, x_start, x_end), (y_tone, y_start, y_end) in product(
213227
zip(pth.x_tones, curr_x, x), zip(pth.y_tones, curr_y, y)
@@ -226,14 +240,14 @@ def render_path(self, pth: path.Path) -> None:
226240
x_tone in self.active_x_tones
227241
and y_tone in self.active_y_tones
228242
)
229-
243+
p = step / (num_arrows - 1) if num_arrows > 1 else 0.0
230244
line = self.ax.arrow(
231245
x_start,
232246
y_start,
233247
dx,
234248
dy,
235-
width=0.1,
236-
color=color_map(step / num_steps),
249+
width=self.arrow_scale,
250+
color=color_map(p),
237251
length_includes_head=True,
238252
linestyle="-" if is_on else (0, (5, 10)),
239253
alpha=1.0 if is_on else 0.5,
@@ -243,9 +257,9 @@ def render_path(self, pth: path.Path) -> None:
243257
line.set_edgecolor(line.get_facecolor())
244258
self.curr_path_lines.append(line)
245259

246-
curr_x = x
247-
curr_y = y
248-
step += 1
260+
if curr_x != x or curr_y != y:
261+
step += 1
262+
self.show()
249263

250264
elif isinstance(action, taskgen.TurnOnAction):
251265
self.active_x_tones.update(x_tones[action.x_tone_indices])

0 commit comments

Comments
 (0)