@@ -249,7 +249,7 @@ def summarize_task(raw_data: pd.DataFrame, task: OTXTaskType, output_root: Path)
249249
250250
251251def task_high_level_summary (raw_data : pd .DataFrame , task : OTXTaskType , output_root : Path ):
252- """Summarize high-level task performance."""
252+ """Summarize high-level task performance over all datasets, with one row per model ."""
253253
254254 raw_task_data = raw_data .query (f"task == '{ task .value } '" )
255255 if raw_task_data is None or len (raw_task_data ) == 0 :
@@ -264,18 +264,36 @@ def task_high_level_summary(raw_data: pd.DataFrame, task: OTXTaskType, output_ro
264264 raw_task_data .loc [:, col ] = raw_task_data [col ].astype (str ) # Prevent strings like '2.0.0' being loaded as float
265265
266266 metrics = raw_task_data .select_dtypes (include = ["number" ]).columns .to_list ()
267- grouped_data = raw_task_data .groupby (["otx_version" , "task" ])
267+
268+ # Group by model instead of just otx_version and task
269+ grouped_data = raw_task_data .groupby (["otx_version" , "task" , "model" ])
268270 aggregated = grouped_data .agg ({metric : ["mean" , "std" ] for metric in metrics }).reset_index ()
269271
270- # Flatten the MultiIndex columns, excluding 'otx_version', 'task', 'model', 'data_group'
272+ # Flatten the MultiIndex columns
271273 cols_to_exclude = {"otx_version" , "task" , "model" , "data" }
272274 aggregated .columns = [
273275 ("_" .join (col ) if col [0 ] not in cols_to_exclude else col [0 ]) for col in aggregated .columns .to_numpy ()
274276 ]
275277
276- # Get metrics (int/float) columns
277- columns = aggregated .select_dtypes (include = ["number" ]).columns
278+ number_cols = aggregated .select_dtypes (include = ["number" ]).columns .to_list ()
279+ meta_cols = aggregated .select_dtypes (include = ["object" ]).columns .to_list ()
280+ for col in meta_cols :
281+ if col in ["model" , "task" , "otx_version" ]:
282+ meta_cols .remove (col )
283+
284+ # Rearrange columns to match the order in aggregate function
285+ aggregated = aggregated .reindex (
286+ columns = [
287+ "otx_version" ,
288+ "task" ,
289+ "model" ,
290+ * number_cols ,
291+ * meta_cols ,
292+ ],
293+ )
294+
278295 # Round all numeric columns to 4 decimal places
296+ columns = aggregated .select_dtypes (include = ["number" ]).columns
279297 for col in columns :
280298 aggregated [col ] = aggregated [col ].round (4 )
281299
0 commit comments