Skip to content

Commit 6a73918

Browse files
authored
Benchmark test high-level summary file (#4272)
Adds a high-level summary of all models across all datasets in a single sheet.
1 parent 0d3cee1 commit 6a73918

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

tests/perf_v2/summary.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def summarize_task(raw_data: pd.DataFrame, task: OTXTaskType, output_root: Path)
249249

250250

251251
def task_high_level_summary(raw_data: pd.DataFrame, task: OTXTaskType, output_root: Path):
252-
"""Summarize high-level task performance."""
252+
"""Summarize high-level task performance over all datasets, with one row per model."""
253253

254254
raw_task_data = raw_data.query(f"task == '{task.value}'")
255255
if raw_task_data is None or len(raw_task_data) == 0:
@@ -264,18 +264,36 @@ def task_high_level_summary(raw_data: pd.DataFrame, task: OTXTaskType, output_ro
264264
raw_task_data.loc[:, col] = raw_task_data[col].astype(str) # Prevent strings like '2.0.0' being loaded as float
265265

266266
metrics = raw_task_data.select_dtypes(include=["number"]).columns.to_list()
267-
grouped_data = raw_task_data.groupby(["otx_version", "task"])
267+
268+
# Group by model instead of just otx_version and task
269+
grouped_data = raw_task_data.groupby(["otx_version", "task", "model"])
268270
aggregated = grouped_data.agg({metric: ["mean", "std"] for metric in metrics}).reset_index()
269271

270-
# Flatten the MultiIndex columns, excluding 'otx_version', 'task', 'model', 'data_group'
272+
# Flatten the MultiIndex columns
271273
cols_to_exclude = {"otx_version", "task", "model", "data"}
272274
aggregated.columns = [
273275
("_".join(col) if col[0] not in cols_to_exclude else col[0]) for col in aggregated.columns.to_numpy()
274276
]
275277

276-
# Get metrics (int/float) columns
277-
columns = aggregated.select_dtypes(include=["number"]).columns
278+
number_cols = aggregated.select_dtypes(include=["number"]).columns.to_list()
279+
meta_cols = aggregated.select_dtypes(include=["object"]).columns.to_list()
280+
for col in meta_cols:
281+
if col in ["model", "task", "otx_version"]:
282+
meta_cols.remove(col)
283+
284+
# Rearrange columns to match the order in aggregate function
285+
aggregated = aggregated.reindex(
286+
columns=[
287+
"otx_version",
288+
"task",
289+
"model",
290+
*number_cols,
291+
*meta_cols,
292+
],
293+
)
294+
278295
# Round all numeric columns to 4 decimal places
296+
columns = aggregated.select_dtypes(include=["number"]).columns
279297
for col in columns:
280298
aggregated[col] = aggregated[col].round(4)
281299

0 commit comments

Comments
 (0)