4141
4242
4343METADATA_ENTRIES = [
44- "seed" ,
4544 "date" ,
4645 "task" ,
4746 "model" ,
@@ -118,12 +117,23 @@ def aggregate(raw_data: pd.DataFrame, metrics: list[str]) -> list[pd.DataFrame]:
118117 if raw_data is None or len (raw_data ) == 0 :
119118 return [pd .DataFrame ()]
120119
120+ # Drop `seed` column if it is in raw_data as we don't need to average it
121+ if "seed" in raw_data .columns :
122+ raw_data = raw_data .drop (columns = ["seed" ])
123+
121124 for col in METADATA_ENTRIES :
122125 raw_data .loc [:, col ] = raw_data [col ].astype (str ) # Prevent strings like '2.0.0' being loaded as float
123126
124- metrics = raw_data .select_dtypes (include = ["number" ]).columns .to_list ()
125127 grouped_data = raw_data .groupby (
126- ["otx_version" , "task" , "model" , "data" , "test_branch" , "test_commit" , "data_group" ],
128+ [
129+ "otx_version" ,
130+ "task" ,
131+ "model" ,
132+ "data" ,
133+ "test_branch" ,
134+ "test_commit" ,
135+ "data_group" ,
136+ ],
127137 )
128138 aggregated = grouped_data .agg ({metric : ["mean" , "std" ] for metric in metrics }).reset_index ()
129139
@@ -133,6 +143,23 @@ def aggregate(raw_data: pd.DataFrame, metrics: list[str]) -> list[pd.DataFrame]:
133143 ("_" .join (col ) if col [0 ] not in cols_to_exclude else col [0 ]) for col in aggregated .columns .to_numpy ()
134144 ]
135145
146+ # Get metrics (int/float) columns
147+ number_cols = aggregated .select_dtypes (include = ["number" ]).columns .to_list ()
148+
149+ # Get metadata (str type) columns
150+ meta_cols = aggregated .select_dtypes (include = ["object" ]).columns .to_list ()
151+ if "model" in meta_cols :
152+ meta_cols .remove ("model" )
153+
154+ rearrange_cols = [
155+ "model" ,
156+ * number_cols ,
157+ * meta_cols ,
158+ ]
159+
160+ # Rearrange columns
161+ aggregated = aggregated .reindex (columns = rearrange_cols )
162+
136163 # Individualize each sheet by dataset
137164 dataset_dfs = []
138165 datase_names = aggregated ["data" ].unique ()
@@ -186,9 +213,9 @@ def create_raw_dataset_xlsx(
186213 """
187214 from tests .perf_v2 import CRITERIA_COLLECTIONS
188215
189- col_names = []
190- col_names .extend (METADATA_ENTRIES )
216+ col_names = ["seed" ]
191217 col_names .extend ([criterion .name for criterion in CRITERIA_COLLECTIONS [task ]])
218+ col_names .extend (METADATA_ENTRIES )
192219
193220 raw_data_task = raw_data .query (f"task == '{ task .value } '" )
194221 for dataset in raw_data_task ["data" ].unique ():
@@ -229,6 +256,10 @@ def task_high_level_summary(raw_data: pd.DataFrame, task: OTXTaskType, output_ro
229256 msg = f"No data found for task { task .value } "
230257 raise ValueError (msg )
231258
259+ # Drop `seed` column if it is in raw_data as we don't need to average it
260+ if "seed" in raw_task_data .columns :
261+ raw_task_data = raw_task_data .drop (columns = ["seed" ])
262+
232263 for col in METADATA_ENTRIES :
233264 raw_task_data .loc [:, col ] = raw_task_data [col ].astype (str ) # Prevent strings like '2.0.0' being loaded as float
234265
@@ -242,6 +273,12 @@ def task_high_level_summary(raw_data: pd.DataFrame, task: OTXTaskType, output_ro
242273 ("_" .join (col ) if col [0 ] not in cols_to_exclude else col [0 ]) for col in aggregated .columns .to_numpy ()
243274 ]
244275
276+ # Get metrics (int/float) columns
277+ columns = aggregated .select_dtypes (include = ["number" ]).columns
278+ # Round all numeric columns to 4 decimal places
279+ for col in columns :
280+ aggregated [col ] = aggregated [col ].round (4 )
281+
245282 # Save the high-level summary data to an Excel file
246283 task_high_level_summary_xlsx_path = output_root / f"{ task .value } -high-level-summary.xlsx"
247284 aggregated .to_excel (task_high_level_summary_xlsx_path , index = False )
0 commit comments