@@ -37,6 +37,7 @@ class MatplotlibRenderer(RendererInterface):
3737 ax : Axes = field (default_factory = default_ax , repr = False )
3838
3939 gate_display_options : GateDisplayOptions = field (default_factory = GateDisplayOptions )
40+ arrow_rescale : float = field (default = 1.0 , kw_only = True )
4041
4142 active_x_tones : set [int ] = field (default_factory = set , repr = False , init = False )
4243 active_y_tones : set [int ] = field (default_factory = set , repr = False , init = False )
@@ -59,6 +60,14 @@ def __post_init__(self) -> None:
5960 self .exit_button = Button (exit_ax , "Exit" )
6061 self .exit_button .on_clicked (lambda event : exit ())
6162
63+ @property
64+ def fov_size (self ) -> float :
65+ return np .sqrt ((self .xmax - self .xmin ) ** 2 + (self .ymax - self .ymin ) ** 2 )
66+
67+ @property
68+ def arrow_scale (self ) -> float :
69+ return self .fov_size / 400.0 * self .arrow_rescale
70+
6271 def update_x_bounds (self , y : float ) -> None :
6372 xmin = min (curr_xmin := getattr (self , "xmin" , float ("inf" )), y - 3 )
6473 xmax = max (curr_xmax := getattr (self , "xmax" , float ("-inf" )), y + 3 )
@@ -185,29 +194,34 @@ def render_path(self, pth: path.Path) -> None:
185194 for way_point in path_action .way_points
186195 ]
187196
188- if len (all_waypoints ) == 0 :
197+ num_unique_waypoints = len (set (all_waypoints ))
198+ if num_unique_waypoints < 2 :
189199 return
190200
201+ num_arrows = num_unique_waypoints - 1
191202 first_waypoint = all_waypoints [0 ]
192203 curr_x = first_waypoint .x_positions
193204 curr_y = first_waypoint .y_positions
194205
195206 color_map = plt .get_cmap ("viridis" )
196207
197- num_steps = len (all_waypoints ) - 1
198208 step = 0
199209
200210 x_tones = np .array (pth .x_tones )
201211 y_tones = np .array (pth .y_tones )
202212
203213 x = all_waypoints [0 ].x_positions
204214 y = all_waypoints [0 ].y_positions
215+ self .clear_paths ()
216+ self .show ()
205217
206218 for action in pth .path :
207219 if isinstance (action , taskgen .WayPointsAction ):
208- for way_point in action .way_points :
209- x = way_point .x_positions
210- y = way_point .y_positions
220+ for start , end in zip (action .way_points [:- 1 ], action .way_points [1 :]):
221+ x = end .x_positions
222+ y = end .y_positions
223+ curr_x = start .x_positions
224+ curr_y = start .y_positions
211225
212226 for (x_tone , x_start , x_end ), (y_tone , y_start , y_end ) in product (
213227 zip (pth .x_tones , curr_x , x ), zip (pth .y_tones , curr_y , y )
@@ -226,14 +240,14 @@ def render_path(self, pth: path.Path) -> None:
226240 x_tone in self .active_x_tones
227241 and y_tone in self .active_y_tones
228242 )
229-
243+ p = step / ( num_arrows - 1 ) if num_arrows > 1 else 0.0
230244 line = self .ax .arrow (
231245 x_start ,
232246 y_start ,
233247 dx ,
234248 dy ,
235- width = 0.1 ,
236- color = color_map (step / num_steps ),
249+ width = self . arrow_scale ,
250+ color = color_map (p ),
237251 length_includes_head = True ,
238252 linestyle = "-" if is_on else (0 , (5 , 10 )),
239253 alpha = 1.0 if is_on else 0.5 ,
@@ -243,9 +257,9 @@ def render_path(self, pth: path.Path) -> None:
243257 line .set_edgecolor (line .get_facecolor ())
244258 self .curr_path_lines .append (line )
245259
246- curr_x = x
247- curr_y = y
248- step += 1
260+ if curr_x ! = x or curr_y != y :
261+ step += 1
262+ self . show ()
249263
250264 elif isinstance (action , taskgen .TurnOnAction ):
251265 self .active_x_tones .update (x_tones [action .x_tone_indices ])
0 commit comments