Skip to content

Commit 6a0c238

Browse files
committed
improve profiling
1 parent d57b096 commit 6a0c238

File tree

1 file changed

+49
-4
lines changed

1 file changed

+49
-4
lines changed

src/agentlab/analyze/agent_xray.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,9 @@ def get_episode_info(info: Info):
922922
923923
{code(step_info.task_info)}
924924
925+
**Terminated or Truncated:**
926+
{code(f"Terminated: {step_info.terminated}, Truncated: {step_info.truncated}")}
927+
925928
**exp_dir:**
926929
927930
<small style="line-height: 1; margin: 0; padding: 0;">{code(exp_dir_str)}</small>"""
@@ -1243,8 +1246,17 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
12431246
warning("No step info to plot")
12441247
return None
12451248

1246-
# this allows to pop labels to make sure we don't use more than 1 for the legend
1247-
labels = ["reset", "env", "agent", "exec action", "action error"]
1249+
# Updated labels to include new profiling stages
1250+
labels = [
1251+
"reset",
1252+
"env",
1253+
"agent",
1254+
"exec action",
1255+
"action error",
1256+
"wait for page",
1257+
"validation",
1258+
"get observation",
1259+
]
12481260
labels = {e: e for e in labels}
12491261

12501262
colors = plt.get_cmap("tab20c").colors
@@ -1253,6 +1265,7 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
12531265
all_times = []
12541266
step_times = []
12551267
for i, step_info in progress_fn(list(enumerate(step_info_list)), desc="Building plot."):
1268+
assert isinstance(step_info, StepInfo), f"Expected StepInfo, got {type(step_info)}"
12561269
step = step_info.step
12571270

12581271
prof = deepcopy(step_info.profiling)
@@ -1274,6 +1287,39 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
12741287
label = labels.pop("exec action", None)
12751288
add_patch(ax, prof.action_exec_start, prof.action_exec_stop, colors[3], label)
12761289

1290+
# NEW: Add wait for page loading visualization
1291+
if (
1292+
hasattr(prof, "wait_for_page_loading_start")
1293+
and prof.wait_for_page_loading_start > 0
1294+
):
1295+
add_patch(
1296+
ax,
1297+
prof.wait_for_page_loading_start,
1298+
prof.wait_for_page_loading_stop,
1299+
colors[19],
1300+
labels.pop("wait for page", None),
1301+
)
1302+
1303+
# NEW: Add validation visualization
1304+
if hasattr(prof, "validation_start") and prof.validation_start > 0:
1305+
add_patch(
1306+
ax,
1307+
prof.validation_start,
1308+
prof.validation_stop,
1309+
colors[8],
1310+
labels.pop("validation", None),
1311+
)
1312+
1313+
# NEW: Add get observation visualization
1314+
if hasattr(prof, "get_observation_start") and prof.get_observation_start > 0:
1315+
add_patch(
1316+
ax,
1317+
prof.get_observation_start,
1318+
prof.get_observation_stop,
1319+
colors[12],
1320+
labels.pop("get observation", None),
1321+
)
1322+
12771323
try:
12781324
next_step_error = step_info_list[i + 1].obs["last_action_error"]
12791325
except (IndexError, KeyError, TypeError):
@@ -1340,7 +1386,6 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
13401386

13411387
ax.set_ylim(0, 1)
13421388
ax.set_xlim(0, max(all_times) + 1)
1343-
# plt.gca().autoscale()
13441389

13451390
ax.set_xlabel("Time")
13461391
ax.set_yticks([])
@@ -1349,7 +1394,7 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
13491394
ax.legend(
13501395
loc="upper center",
13511396
bbox_to_anchor=(0.5, 1.2),
1352-
ncol=5,
1397+
ncol=8, # Updated to accommodate new labels
13531398
frameon=True,
13541399
)
13551400

0 commit comments

Comments
 (0)