@@ -74,6 +74,7 @@ class EpisodeId:
7474 agent_id : str = None
7575 task_name : str = None
7676 seed : int = None
77+ row_index : int = None # unique row index to disambiguate selections
7778
7879
7980@dataclass
@@ -99,6 +100,24 @@ def update_exp_result(self, episode_id: EpisodeId):
99100 if self .result_df is None or episode_id .task_name is None or episode_id .seed is None :
100101 self .exp_result = None
101102
103+ # Prefer selecting by explicit row index if available
104+ if episode_id .row_index is not None :
105+ tmp_df = self .result_df .reset_index (inplace = False )
106+ tmp_df ["_row_index" ] = tmp_df .index
107+ sub_df = tmp_df [tmp_df ["_row_index" ] == episode_id .row_index ]
108+ if len (sub_df ) == 0 :
109+ self .exp_result = None
110+ raise ValueError (f"Could not find episode for row_index: { episode_id .row_index } " )
111+ if len (sub_df ) > 1 :
112+ warning (
113+ f"Found multiple rows for row_index: { episode_id .row_index } . Using the first one."
114+ )
115+ exp_dir = sub_df .iloc [0 ]["exp_dir" ]
116+ print (exp_dir )
117+ self .exp_result = ExpResult (exp_dir )
118+ self .step = 0
119+ return
120+
102121 # find unique row for task_name and seed
103122 result_df = self .agent_df .reset_index (inplace = False )
104123 sub_df = result_df [
@@ -128,16 +147,15 @@ def get_agent_id(self, row: pd.Series):
128147 return agent_id
129148
130149 def filter_agent_id (self , agent_id : list [tuple ]):
131- # query_str = " & ".join([f"`{col}` == {repr(val)}" for col, val in agent_id])
132- # agent_df = info.result_df.query(query_str)
133-
134- agent_df = self .result_df .reset_index (inplace = False )
135- agent_df .set_index (TASK_NAME_KEY , inplace = True )
150+ # Preserve a stable row index to disambiguate selections later
151+ tmp_df = self .result_df .reset_index (inplace = False )
152+ tmp_df ["_row_index" ] = tmp_df .index
153+ tmp_df .set_index (TASK_NAME_KEY , inplace = True )
136154
137155 for col , val in agent_id :
138156 col = col .replace (".\n " , "." )
139- agent_df = agent_df [ agent_df [col ] == val ]
140- self .agent_df = agent_df
157+ tmp_df = tmp_df [ tmp_df [col ] == val ]
158+ self .agent_df = tmp_df
141159
142160
143161info = Info ()
@@ -735,7 +753,7 @@ def dict_msg_to_markdown(d: dict):
735753 case _:
736754 parts .append (f"\n ```\n { str (item )} \n ```\n " )
737755
738- markdown = f"### { d [" role" ].capitalize ()} \n "
756+ markdown = f"### { d [' role' ].capitalize ()} \n "
739757 markdown += "\n " .join (parts )
740758 return markdown
741759
@@ -1003,14 +1021,17 @@ def get_seeds_df(result_df: pd.DataFrame, task_name: str):
10031021 def extract_columns (row : pd .Series ):
10041022 return pd .Series (
10051023 {
1006- "seed" : row [TASK_SEED_KEY ],
1024+ "index" : row .get ("_row_index" , None ),
1025+ "seed" : row .get (TASK_SEED_KEY , None ),
10071026 "reward" : row .get ("cum_reward" , None ),
10081027 "err" : bool (row .get ("err_msg" , None )),
10091028 "n_steps" : row .get ("n_steps" , None ),
10101029 }
10111030 )
10121031
10131032 seed_df = result_df .apply (extract_columns , axis = 1 )
1033+ # Ensure column order and readability
1034+ seed_df = seed_df [["seed" , "reward" , "err" , "n_steps" ,"index" ]]
10141035 return seed_df
10151036
10161037
@@ -1028,15 +1049,26 @@ def on_select_task(evt: gr.SelectData, df: pd.DataFrame, agent_id: list[tuple]):
10281049def update_seeds (agent_task_id : tuple ):
10291050 agent_id , task_name = agent_task_id
10301051 seed_df = get_seeds_df (info .agent_df , task_name )
1031- first_seed = seed_df .iloc [0 ]["seed" ]
1032- return seed_df , EpisodeId (agent_id = agent_id , task_name = task_name , seed = first_seed )
1052+ first_seed = int (seed_df .iloc [0 ]["seed" ]) if len (seed_df ) else None
1053+ first_index = int (seed_df .iloc [0 ]["index" ]) if len (seed_df ) else None
1054+ return seed_df , EpisodeId (
1055+ agent_id = agent_id , task_name = task_name , seed = first_seed , row_index = first_index
1056+ )
10331057
10341058
10351059def on_select_seed (evt : gr .SelectData , df : pd .DataFrame , agent_task_id : tuple ):
10361060 agent_id , task_name = agent_task_id
10371061 col_idx = df .columns .get_loc ("seed" )
1038- seed = evt .row_value [col_idx ] # seed should be the first column
1039- return EpisodeId (agent_id = agent_id , task_name = task_name , seed = seed )
1062+ idx_col = df .columns .get_loc ("index" ) if "index" in df .columns else None
1063+ seed = evt .row_value [col_idx ]
1064+ row_index = evt .row_value [idx_col ] if idx_col is not None else None
1065+ try :
1066+ seed = int (seed )
1067+ if row_index is not None :
1068+ row_index = int (row_index )
1069+ except Exception :
1070+ pass
1071+ return EpisodeId (agent_id = agent_id , task_name = task_name , seed = seed , row_index = row_index )
10401072
10411073
10421074def new_episode (episode_id : EpisodeId , progress = gr .Progress ()):
0 commit comments