@@ -922,6 +922,9 @@ def get_episode_info(info: Info):
922922
923923{ code (step_info .task_info )}
924924
925+ **Terminated or Truncated:**
926+ { code (f"Terminated: { step_info .terminated } , Truncated: { step_info .truncated } " )}
927+
925928**exp_dir:**
926929
927930<small style="line-height: 1; margin: 0; padding: 0;">{ code (exp_dir_str )} </small>"""
@@ -1243,8 +1246,17 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
12431246 warning ("No step info to plot" )
12441247 return None
12451248
1246- # this allows to pop labels to make sure we don't use more than 1 for the legend
1247- labels = ["reset" , "env" , "agent" , "exec action" , "action error" ]
1249+ # Updated labels to include new profiling stages
1250+ labels = [
1251+ "reset" ,
1252+ "env" ,
1253+ "agent" ,
1254+ "exec action" ,
1255+ "action error" ,
1256+ "wait for page" ,
1257+ "validation" ,
1258+ "get observation" ,
1259+ ]
12481260 labels = {e : e for e in labels }
12491261
12501262 colors = plt .get_cmap ("tab20c" ).colors
@@ -1253,6 +1265,7 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
12531265 all_times = []
12541266 step_times = []
12551267 for i , step_info in progress_fn (list (enumerate (step_info_list )), desc = "Building plot." ):
1268+ assert isinstance (step_info , StepInfo ), f"Expected StepInfo, got { type (step_info )} "
12561269 step = step_info .step
12571270
12581271 prof = deepcopy (step_info .profiling )
@@ -1274,6 +1287,39 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
12741287 label = labels .pop ("exec action" , None )
12751288 add_patch (ax , prof .action_exec_start , prof .action_exec_stop , colors [3 ], label )
12761289
1290+ # NEW: Add wait for page loading visualization
1291+ if (
1292+ hasattr (prof , "wait_for_page_loading_start" )
1293+ and prof .wait_for_page_loading_start > 0
1294+ ):
1295+ add_patch (
1296+ ax ,
1297+ prof .wait_for_page_loading_start ,
1298+ prof .wait_for_page_loading_stop ,
1299+ colors [19 ],
1300+ labels .pop ("wait for page" , None ),
1301+ )
1302+
1303+ # NEW: Add validation visualization
1304+ if hasattr (prof , "validation_start" ) and prof .validation_start > 0 :
1305+ add_patch (
1306+ ax ,
1307+ prof .validation_start ,
1308+ prof .validation_stop ,
1309+ colors [8 ],
1310+ labels .pop ("validation" , None ),
1311+ )
1312+
1313+ # NEW: Add get observation visualization
1314+ if hasattr (prof , "get_observation_start" ) and prof .get_observation_start > 0 :
1315+ add_patch (
1316+ ax ,
1317+ prof .get_observation_start ,
1318+ prof .get_observation_stop ,
1319+ colors [12 ],
1320+ labels .pop ("get observation" , None ),
1321+ )
1322+
12771323 try :
12781324 next_step_error = step_info_list [i + 1 ].obs ["last_action_error" ]
12791325 except (IndexError , KeyError , TypeError ):
@@ -1340,7 +1386,6 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
13401386
13411387 ax .set_ylim (0 , 1 )
13421388 ax .set_xlim (0 , max (all_times ) + 1 )
1343- # plt.gca().autoscale()
13441389
13451390 ax .set_xlabel ("Time" )
13461391 ax .set_yticks ([])
@@ -1349,7 +1394,7 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
13491394 ax .legend (
13501395 loc = "upper center" ,
13511396 bbox_to_anchor = (0.5 , 1.2 ),
1352- ncol = 5 ,
1397+ ncol = 8 , # Updated to accommodate new labels
13531398 frameon = True ,
13541399 )
13551400
0 commit comments