Skip to content

Commit 15f4c6d

Browse files
committed
xray bugfix
1 parent 219e467 commit 15f4c6d

File tree

1 file changed

+45
-13
lines changed

1 file changed

+45
-13
lines changed

src/agentlab/analyze/agent_xray.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class EpisodeId:
7474
agent_id: str = None
7575
task_name: str = None
7676
seed: int = None
77+
row_index: int = None # unique row index to disambiguate selections
7778

7879

7980
@dataclass
@@ -99,6 +100,24 @@ def update_exp_result(self, episode_id: EpisodeId):
99100
if self.result_df is None or episode_id.task_name is None or episode_id.seed is None:
100101
self.exp_result = None
101102

103+
# Prefer selecting by explicit row index if available
104+
if episode_id.row_index is not None:
105+
tmp_df = self.result_df.reset_index(inplace=False)
106+
tmp_df["_row_index"] = tmp_df.index
107+
sub_df = tmp_df[tmp_df["_row_index"] == episode_id.row_index]
108+
if len(sub_df) == 0:
109+
self.exp_result = None
110+
raise ValueError(f"Could not find episode for row_index: {episode_id.row_index}")
111+
if len(sub_df) > 1:
112+
warning(
113+
f"Found multiple rows for row_index: {episode_id.row_index}. Using the first one."
114+
)
115+
exp_dir = sub_df.iloc[0]["exp_dir"]
116+
print(exp_dir)
117+
self.exp_result = ExpResult(exp_dir)
118+
self.step = 0
119+
return
120+
102121
# find unique row for task_name and seed
103122
result_df = self.agent_df.reset_index(inplace=False)
104123
sub_df = result_df[
@@ -128,16 +147,15 @@ def get_agent_id(self, row: pd.Series):
128147
return agent_id
129148

130149
def filter_agent_id(self, agent_id: list[tuple]):
131-
# query_str = " & ".join([f"`{col}` == {repr(val)}" for col, val in agent_id])
132-
# agent_df = info.result_df.query(query_str)
133-
134-
agent_df = self.result_df.reset_index(inplace=False)
135-
agent_df.set_index(TASK_NAME_KEY, inplace=True)
150+
# Preserve a stable row index to disambiguate selections later
151+
tmp_df = self.result_df.reset_index(inplace=False)
152+
tmp_df["_row_index"] = tmp_df.index
153+
tmp_df.set_index(TASK_NAME_KEY, inplace=True)
136154

137155
for col, val in agent_id:
138156
col = col.replace(".\n", ".")
139-
agent_df = agent_df[agent_df[col] == val]
140-
self.agent_df = agent_df
157+
tmp_df = tmp_df[tmp_df[col] == val]
158+
self.agent_df = tmp_df
141159

142160

143161
info = Info()
@@ -735,7 +753,7 @@ def dict_msg_to_markdown(d: dict):
735753
case _:
736754
parts.append(f"\n```\n{str(item)}\n```\n")
737755

738-
markdown = f"### {d["role"].capitalize()}\n"
756+
markdown = f"### {d['role'].capitalize()}\n"
739757
markdown += "\n".join(parts)
740758
return markdown
741759

@@ -1003,14 +1021,17 @@ def get_seeds_df(result_df: pd.DataFrame, task_name: str):
10031021
def extract_columns(row: pd.Series):
10041022
return pd.Series(
10051023
{
1006-
"seed": row[TASK_SEED_KEY],
1024+
"index": row.get("_row_index", None),
1025+
"seed": row.get(TASK_SEED_KEY, None),
10071026
"reward": row.get("cum_reward", None),
10081027
"err": bool(row.get("err_msg", None)),
10091028
"n_steps": row.get("n_steps", None),
10101029
}
10111030
)
10121031

10131032
seed_df = result_df.apply(extract_columns, axis=1)
1033+
# Ensure column order and readability
1034+
seed_df = seed_df[["seed", "reward", "err", "n_steps","index"]]
10141035
return seed_df
10151036

10161037

@@ -1028,15 +1049,26 @@ def on_select_task(evt: gr.SelectData, df: pd.DataFrame, agent_id: list[tuple]):
10281049
def update_seeds(agent_task_id: tuple):
10291050
agent_id, task_name = agent_task_id
10301051
seed_df = get_seeds_df(info.agent_df, task_name)
1031-
first_seed = seed_df.iloc[0]["seed"]
1032-
return seed_df, EpisodeId(agent_id=agent_id, task_name=task_name, seed=first_seed)
1052+
first_seed = int(seed_df.iloc[0]["seed"]) if len(seed_df) else None
1053+
first_index = int(seed_df.iloc[0]["index"]) if len(seed_df) else None
1054+
return seed_df, EpisodeId(
1055+
agent_id=agent_id, task_name=task_name, seed=first_seed, row_index=first_index
1056+
)
10331057

10341058

10351059
def on_select_seed(evt: gr.SelectData, df: pd.DataFrame, agent_task_id: tuple):
10361060
agent_id, task_name = agent_task_id
10371061
col_idx = df.columns.get_loc("seed")
1038-
seed = evt.row_value[col_idx] # seed should be the first column
1039-
return EpisodeId(agent_id=agent_id, task_name=task_name, seed=seed)
1062+
idx_col = df.columns.get_loc("index") if "index" in df.columns else None
1063+
seed = evt.row_value[col_idx]
1064+
row_index = evt.row_value[idx_col] if idx_col is not None else None
1065+
try:
1066+
seed = int(seed)
1067+
if row_index is not None:
1068+
row_index = int(row_index)
1069+
except Exception:
1070+
pass
1071+
return EpisodeId(agent_id=agent_id, task_name=task_name, seed=seed, row_index=row_index)
10401072

10411073

10421074
def new_episode(episode_id: EpisodeId, progress=gr.Progress()):

0 commit comments

Comments
 (0)