Skip to content

Commit a094271

Browse files
authored
Merge pull request #4264 from eugene123tw/eugene/rearrange-benchmark-summary-table
Refactor summary aggregation to drop 'seed' column and rearrange output columns
2 parents d88a8c6 + 470d946 commit a094271

File tree

1 file changed

+42
-5
lines changed

1 file changed

+42
-5
lines changed

tests/perf_v2/summary.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141

4242

4343
METADATA_ENTRIES = [
44-
"seed",
4544
"date",
4645
"task",
4746
"model",
@@ -118,12 +117,23 @@ def aggregate(raw_data: pd.DataFrame, metrics: list[str]) -> list[pd.DataFrame]:
118117
if raw_data is None or len(raw_data) == 0:
119118
return [pd.DataFrame()]
120119

120+
# Drop `seed` column if it is in raw_data as we don't need to average it
121+
if "seed" in raw_data.columns:
122+
raw_data = raw_data.drop(columns=["seed"])
123+
121124
for col in METADATA_ENTRIES:
122125
raw_data.loc[:, col] = raw_data[col].astype(str) # Prevent strings like '2.0.0' being loaded as float
123126

124-
metrics = raw_data.select_dtypes(include=["number"]).columns.to_list()
125127
grouped_data = raw_data.groupby(
126-
["otx_version", "task", "model", "data", "test_branch", "test_commit", "data_group"],
128+
[
129+
"otx_version",
130+
"task",
131+
"model",
132+
"data",
133+
"test_branch",
134+
"test_commit",
135+
"data_group",
136+
],
127137
)
128138
aggregated = grouped_data.agg({metric: ["mean", "std"] for metric in metrics}).reset_index()
129139

@@ -133,6 +143,23 @@ def aggregate(raw_data: pd.DataFrame, metrics: list[str]) -> list[pd.DataFrame]:
133143
("_".join(col) if col[0] not in cols_to_exclude else col[0]) for col in aggregated.columns.to_numpy()
134144
]
135145

146+
# Get metrics (int/float) columns
147+
number_cols = aggregated.select_dtypes(include=["number"]).columns.to_list()
148+
149+
# Get metadata (str type) columns
150+
meta_cols = aggregated.select_dtypes(include=["object"]).columns.to_list()
151+
if "model" in meta_cols:
152+
meta_cols.remove("model")
153+
154+
rearrange_cols = [
155+
"model",
156+
*number_cols,
157+
*meta_cols,
158+
]
159+
160+
# Rearrange columns
161+
aggregated = aggregated.reindex(columns=rearrange_cols)
162+
136163
# Individualize each sheet by dataset
137164
dataset_dfs = []
138165
datase_names = aggregated["data"].unique()
@@ -186,9 +213,9 @@ def create_raw_dataset_xlsx(
186213
"""
187214
from tests.perf_v2 import CRITERIA_COLLECTIONS
188215

189-
col_names = []
190-
col_names.extend(METADATA_ENTRIES)
216+
col_names = ["seed"]
191217
col_names.extend([criterion.name for criterion in CRITERIA_COLLECTIONS[task]])
218+
col_names.extend(METADATA_ENTRIES)
192219

193220
raw_data_task = raw_data.query(f"task == '{task.value}'")
194221
for dataset in raw_data_task["data"].unique():
@@ -229,6 +256,10 @@ def task_high_level_summary(raw_data: pd.DataFrame, task: OTXTaskType, output_ro
229256
msg = f"No data found for task {task.value}"
230257
raise ValueError(msg)
231258

259+
# Drop `seed` column if it is in raw_task_data as we don't need to average it
260+
if "seed" in raw_task_data.columns:
261+
raw_task_data = raw_task_data.drop(columns=["seed"])
262+
232263
for col in METADATA_ENTRIES:
233264
raw_task_data.loc[:, col] = raw_task_data[col].astype(str) # Prevent strings like '2.0.0' being loaded as float
234265

@@ -242,6 +273,12 @@ def task_high_level_summary(raw_data: pd.DataFrame, task: OTXTaskType, output_ro
242273
("_".join(col) if col[0] not in cols_to_exclude else col[0]) for col in aggregated.columns.to_numpy()
243274
]
244275

276+
# Get metrics (int/float) columns
277+
columns = aggregated.select_dtypes(include=["number"]).columns
278+
# Round all numeric columns to 4 decimal places
279+
for col in columns:
280+
aggregated[col] = aggregated[col].round(4)
281+
245282
# Save the high-level summary data to an Excel file
246283
task_high_level_summary_xlsx_path = output_root / f"{task.value}-high-level-summary.xlsx"
247284
aggregated.to_excel(task_high_level_summary_xlsx_path, index=False)

0 commit comments

Comments
 (0)