Skip to content

Commit f2e7b40

Browse files
clean up xray-fixes
1 parent 9f3d706 commit f2e7b40

File tree

1 file changed

+9
-48
lines changed

1 file changed

+9
-48
lines changed

src/agentlab/analyze/agent_xray.py

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -101,42 +101,9 @@ def update_exp_result(self, episode_id: EpisodeId):
101101
if self.result_df is None or episode_id.task_name is None or episode_id.seed is None:
102102
self.exp_result = None
103103

104-
# Prefer selecting by explicit row index if available
105-
if episode_id.row_index is not None:
106-
tmp_df = self.result_df.reset_index(inplace=False)
107-
tmp_df["_row_index"] = tmp_df.index
108-
sub_df = tmp_df[tmp_df["_row_index"] == episode_id.row_index]
109-
if len(sub_df) == 0:
110-
self.exp_result = None
111-
raise ValueError(f"Could not find episode for row_index: {episode_id.row_index}")
112-
if len(sub_df) > 1:
113-
warning(
114-
f"Found multiple rows for row_index: {episode_id.row_index}. Using the first one."
115-
)
116-
exp_dir = sub_df.iloc[0]["exp_dir"]
117-
print(exp_dir)
118-
self.exp_result = ExpResult(exp_dir)
119-
self.step = 0
120-
return
121-
122-
# find unique row for task_name and seed
104+
# find unique row using idx
123105
result_df = self.agent_df.reset_index(inplace=False)
124-
sub_df = result_df[
125-
(result_df[TASK_NAME_KEY] == episode_id.task_name)
126-
& (result_df[TASK_SEED_KEY] == episode_id.seed)
127-
]
128-
if len(sub_df) == 0:
129-
self.exp_result = None
130-
raise ValueError(
131-
f"Could not find task_name: {episode_id.task_name} and seed: {episode_id.seed}"
132-
)
133-
134-
if len(sub_df) > 1:
135-
warning(
136-
f"Found multiple rows for task_name: {episode_id.task_name} and seed: {episode_id.seed}. Using the first one."
137-
)
138-
139-
exp_dir = sub_df.iloc[0]["exp_dir"]
106+
exp_dir = result_df.iloc[episode_id.row_index]["exp_dir"]
140107
print(exp_dir)
141108
self.exp_result = ExpResult(exp_dir)
142109
self.step = 0
@@ -1022,7 +989,7 @@ def get_seeds_df(result_df: pd.DataFrame, task_name: str):
1022989
def extract_columns(row: pd.Series):
1023990
return pd.Series(
1024991
{
1025-
"index": row.get("_row_index", None),
992+
"idx": row.get("_row_index", None),
1026993
"seed": row.get(TASK_SEED_KEY, None),
1027994
"reward": row.get("cum_reward", None),
1028995
"err": bool(row.get("err_msg", None)),
@@ -1032,7 +999,7 @@ def extract_columns(row: pd.Series):
1032999

10331000
seed_df = result_df.apply(extract_columns, axis=1)
10341001
# Ensure column order and readability
1035-
seed_df = seed_df[["seed", "reward", "err", "n_steps","index"]]
1002+
seed_df = seed_df[["seed", "reward", "err", "n_steps", "idx"]]
10361003
return seed_df
10371004

10381005

@@ -1050,8 +1017,8 @@ def on_select_task(evt: gr.SelectData, df: pd.DataFrame, agent_id: list[tuple]):
10501017
def update_seeds(agent_task_id: tuple):
10511018
agent_id, task_name = agent_task_id
10521019
seed_df = get_seeds_df(info.agent_df, task_name)
1053-
first_seed = int(seed_df.iloc[0]["seed"]) if len(seed_df) else None
1054-
first_index = int(seed_df.iloc[0]["index"]) if len(seed_df) else None
1020+
first_seed = int(seed_df.iloc[0]["seed"])
1021+
first_index = int(seed_df.iloc[0]["idx"])
10551022
return seed_df, EpisodeId(
10561023
agent_id=agent_id, task_name=task_name, seed=first_seed, row_index=first_index
10571024
)
@@ -1060,15 +1027,9 @@ def update_seeds(agent_task_id: tuple):
10601027
def on_select_seed(evt: gr.SelectData, df: pd.DataFrame, agent_task_id: tuple):
10611028
agent_id, task_name = agent_task_id
10621029
col_idx = df.columns.get_loc("seed")
1063-
idx_col = df.columns.get_loc("index") if "index" in df.columns else None
1030+
idx_col = df.columns.get_loc("idx")
10641031
seed = evt.row_value[col_idx]
1065-
row_index = evt.row_value[idx_col] if idx_col is not None else None
1066-
try:
1067-
seed = int(seed)
1068-
if row_index is not None:
1069-
row_index = int(row_index)
1070-
except Exception:
1071-
pass
1032+
row_index = evt.row_value[idx_col]
10721033
return EpisodeId(agent_id=agent_id, task_name=task_name, seed=seed, row_index=row_index)
10731034

10741035

@@ -1167,7 +1128,7 @@ def new_exp_dir(study_names: list, progress=gr.Progress(), just_refresh=False):
11671128
study_names.remove(select_dir_instructions)
11681129

11691130
if len(study_names) == 0:
1170-
return None, None
1131+
return None, None, None, None, None, None
11711132

11721133
info.study_dirs = [info.results_dir / study_name.split(" - ")[0] for study_name in study_names]
11731134
info.result_df = inspect_results.load_result_df(info.study_dirs, progress_fn=progress.tqdm)

0 commit comments

Comments
 (0)