@@ -101,42 +101,9 @@ def update_exp_result(self, episode_id: EpisodeId):
101101 if self .result_df is None or episode_id .task_name is None or episode_id .seed is None :
102102 self .exp_result = None
103103
104- # Prefer selecting by explicit row index if available
105- if episode_id .row_index is not None :
106- tmp_df = self .result_df .reset_index (inplace = False )
107- tmp_df ["_row_index" ] = tmp_df .index
108- sub_df = tmp_df [tmp_df ["_row_index" ] == episode_id .row_index ]
109- if len (sub_df ) == 0 :
110- self .exp_result = None
111- raise ValueError (f"Could not find episode for row_index: { episode_id .row_index } " )
112- if len (sub_df ) > 1 :
113- warning (
114- f"Found multiple rows for row_index: { episode_id .row_index } . Using the first one."
115- )
116- exp_dir = sub_df .iloc [0 ]["exp_dir" ]
117- print (exp_dir )
118- self .exp_result = ExpResult (exp_dir )
119- self .step = 0
120- return
121-
122- # find unique row for task_name and seed
104+ # find unique row using idx
123105 result_df = self .agent_df .reset_index (inplace = False )
124- sub_df = result_df [
125- (result_df [TASK_NAME_KEY ] == episode_id .task_name )
126- & (result_df [TASK_SEED_KEY ] == episode_id .seed )
127- ]
128- if len (sub_df ) == 0 :
129- self .exp_result = None
130- raise ValueError (
131- f"Could not find task_name: { episode_id .task_name } and seed: { episode_id .seed } "
132- )
133-
134- if len (sub_df ) > 1 :
135- warning (
136- f"Found multiple rows for task_name: { episode_id .task_name } and seed: { episode_id .seed } . Using the first one."
137- )
138-
139- exp_dir = sub_df .iloc [0 ]["exp_dir" ]
106+ exp_dir = result_df .iloc [episode_id .row_index ]["exp_dir" ]
140107 print (exp_dir )
141108 self .exp_result = ExpResult (exp_dir )
142109 self .step = 0
@@ -1022,7 +989,7 @@ def get_seeds_df(result_df: pd.DataFrame, task_name: str):
1022989 def extract_columns (row : pd .Series ):
1023990 return pd .Series (
1024991 {
1025- "index " : row .get ("_row_index" , None ),
992+ "idx " : row .get ("_row_index" , None ),
1026993 "seed" : row .get (TASK_SEED_KEY , None ),
1027994 "reward" : row .get ("cum_reward" , None ),
1028995 "err" : bool (row .get ("err_msg" , None )),
@@ -1032,7 +999,7 @@ def extract_columns(row: pd.Series):
1032999
10331000 seed_df = result_df .apply (extract_columns , axis = 1 )
10341001 # Ensure column order and readability
1035- seed_df = seed_df [["seed" , "reward" , "err" , "n_steps" ,"index " ]]
1002+ seed_df = seed_df [["seed" , "reward" , "err" , "n_steps" , "idx " ]]
10361003 return seed_df
10371004
10381005
@@ -1050,8 +1017,8 @@ def on_select_task(evt: gr.SelectData, df: pd.DataFrame, agent_id: list[tuple]):
10501017def update_seeds (agent_task_id : tuple ):
10511018 agent_id , task_name = agent_task_id
10521019 seed_df = get_seeds_df (info .agent_df , task_name )
1053- first_seed = int (seed_df .iloc [0 ]["seed" ]) if len ( seed_df ) else None
1054- first_index = int (seed_df .iloc [0 ]["index " ]) if len ( seed_df ) else None
1020+ first_seed = int (seed_df .iloc [0 ]["seed" ])
1021+ first_index = int (seed_df .iloc [0 ]["idx " ])
10551022 return seed_df , EpisodeId (
10561023 agent_id = agent_id , task_name = task_name , seed = first_seed , row_index = first_index
10571024 )
@@ -1060,15 +1027,9 @@ def update_seeds(agent_task_id: tuple):
10601027def on_select_seed (evt : gr .SelectData , df : pd .DataFrame , agent_task_id : tuple ):
10611028 agent_id , task_name = agent_task_id
10621029 col_idx = df .columns .get_loc ("seed" )
1063- idx_col = df .columns .get_loc ("index" ) if "index" in df . columns else None
1030+ idx_col = df .columns .get_loc ("idx" )
10641031 seed = evt .row_value [col_idx ]
1065- row_index = evt .row_value [idx_col ] if idx_col is not None else None
1066- try :
1067- seed = int (seed )
1068- if row_index is not None :
1069- row_index = int (row_index )
1070- except Exception :
1071- pass
1032+ row_index = evt .row_value [idx_col ]
10721033 return EpisodeId (agent_id = agent_id , task_name = task_name , seed = seed , row_index = row_index )
10731034
10741035
@@ -1167,7 +1128,7 @@ def new_exp_dir(study_names: list, progress=gr.Progress(), just_refresh=False):
11671128 study_names .remove (select_dir_instructions )
11681129
11691130 if len (study_names ) == 0 :
1170- return None , None
1131+ return None , None , None , None , None , None
11711132
11721133 info .study_dirs = [info .results_dir / study_name .split (" - " )[0 ] for study_name in study_names ]
11731134 info .result_df = inspect_results .load_result_df (info .study_dirs , progress_fn = progress .tqdm )
0 commit comments