
Commit 1ed6d34

fix: aggregate tasks repeats results by the task id
1 parent: 757f6f4

File tree

2 files changed: +22 additions, -17 deletions


src/rai_bench/rai_bench/base_benchmark.py

Lines changed: 5 additions & 3 deletions
@@ -55,6 +55,10 @@ class ModelSummary(BaseModel):
 
 class TasksSummary(BaseModel):
     model_name: str = Field(..., description="Name of the LLM.")
+    task_id: str = Field(..., description="Unique identifier for the task.")
+    task_prompt: str = Field(
+        ..., description="The task prompt that identifies the task."
+    )
     avg_success_rate: float = Field(
         ..., description="Average result for task across all repeats."
     )
@@ -68,9 +72,7 @@ class TasksSummary(BaseModel):
         ...,
         description="Standard deviation of the time taken across all repeats for one task.",
     )
-    total_tasks: int = Field(
-        ..., description="Total number of executed tasks across all repeats per task."
-    )
+    repeats: int = Field(..., description="Total number of repeats for task.")
 
 
 class TimeoutException(Exception):
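With the new fields, a single summary row now fully identifies a task by its id as well as its prompt. A minimal sketch of constructing the updated model directly (all values below are invented for illustration):

# Hypothetical instantiation of the updated TasksSummary; field values are illustrative only.
summary = TasksSummary(
    model_name="example-llm",
    task_id="task-001",
    task_prompt="Move the cube onto the tray.",
    avg_success_rate=0.667,
    std_success_rate=0.471,
    avg_time=12.345,
    std_time=1.234,
    repeats=3,
)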

src/rai_bench/rai_bench/test_models.py

Lines changed: 17 additions & 14 deletions
@@ -216,51 +216,52 @@ def merge_tasks_summary(bench_name: str, model_name: str, run_dir: Path) -> None
     if not model_dir.exists():
         return
 
-    # Collect all task results from all repeats
-    task_data_by_prompt: Dict[str, Dict[str, List[float]]] = {}
+    task_data_by_id: Dict[str, Dict[str, Any]] = {}
 
     for repeat_dir in model_dir.iterdir():
         if repeat_dir.is_dir() and repeat_dir.name.isdigit():
             results_file = repeat_dir / DETAILED_FILE_NAME
             if results_file.exists():
-                # Read detailed results from this repeat
                 with open(results_file, "r") as f:
                     reader = csv.DictReader(f)
                     for row in reader:
+                        task_id = row["task_id"]
                         task_prompt = row["task_prompt"]
                         score = float(row["score"])
                         total_time = float(row["total_time"])
 
-                        if task_prompt not in task_data_by_prompt:
-                            task_data_by_prompt[task_prompt] = {
+                        if task_id not in task_data_by_id:
+                            task_data_by_id[task_id] = {
                                 "scores": [],
                                 "times": [],
+                                "task_prompt": task_prompt,
                             }
 
-                        task_data_by_prompt[task_prompt]["scores"].append(score)
-                        task_data_by_prompt[task_prompt]["times"].append(total_time)
+                        task_data_by_id[task_id]["scores"].append(score)
+                        task_data_by_id[task_id]["times"].append(total_time)
 
-    if not task_data_by_prompt:
+    if not task_data_by_id:
         return
 
     # Calculate statistics for each task
     task_summaries: List[TasksSummary] = []
-    for task_prompt, data in task_data_by_prompt.items():
+    for task_id, data in task_data_by_id.items():
         scores = np.array(data["scores"])
         times = np.array(data["times"])
+        task_prompt = data["task_prompt"]
 
         task_summary = TasksSummary(
             model_name=model_name,
+            task_id=task_id,
             task_prompt=task_prompt,
             avg_success_rate=round(float(scores.mean()), 3),
             std_success_rate=round(float(scores.std()), 3),
             avg_time=round(float(times.mean()), 3),
             std_time=round(float(times.std()), 3),
-            repeats=len(scores),  # TODO (mkotynia) (extract repeats in another way)
+            repeats=len(scores),
         )
         task_summaries.append(task_summary)
 
-    # Save task summaries to CSV
     tasks_summary_file = model_dir / TASKS_SUMMARY_FILE_NAME
     with open(tasks_summary_file, "w", newline="") as f:
         if task_summaries:
@@ -420,15 +421,17 @@ def test_models(
                 bench_logger.critical(
                     f"{bench_conf.name} benchmark for {model_name}, vendor: {vendors[i]}, execution number: {u + 1}"
                 )
-        # TODO (mkotynia): resolve unbound bench_logger
-        bench_logger.info(f"Merging summaries for benchmark: {bench_conf.name}")
+        merge_results_logger = define_benchmark_logger(out_dir=Path(out_dir))
+        merge_results_logger.info(
+            f"Merging summaries for benchmark: {bench_conf.name}"
+        )
 
         for model_name in model_names:
            merge_model_repeats_summary(bench_conf.name, model_name, run_dir)
            merge_tasks_summary(bench_conf.name, model_name, run_dir)
 
         merge_benchmark_summary(bench_conf.name, run_dir, model_names)
 
-        bench_logger.info(
+        merge_results_logger.info(
             f"Summary merging completed for benchmark: {bench_conf.name}"
         )
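The substance of the fix is the grouping key: per-repeat rows are now bucketed by task_id rather than task_prompt, so two distinct tasks that happen to share a prompt are no longer merged into one summary, and repeats=len(scores) counts only that task's own repeats. A minimal, self-contained sketch of that grouping, separate from the project's code (row values below are invented for illustration):

# Sketch only: simulate per-repeat detailed rows where two different tasks share a prompt.
from collections import defaultdict

rows = [
    {"task_id": "t1", "task_prompt": "Pick up the cube", "score": 1.0, "total_time": 10.0},
    {"task_id": "t2", "task_prompt": "Pick up the cube", "score": 0.0, "total_time": 12.0},
    {"task_id": "t1", "task_prompt": "Pick up the cube", "score": 1.0, "total_time": 11.0},
]

grouped = defaultdict(lambda: {"scores": [], "times": [], "task_prompt": ""})
for row in rows:
    entry = grouped[row["task_id"]]  # key by task_id, not task_prompt
    entry["task_prompt"] = row["task_prompt"]
    entry["scores"].append(float(row["score"]))
    entry["times"].append(float(row["total_time"]))

# Keying by prompt would collapse t1 and t2 into one entry; keying by id keeps them separate.
print({tid: (len(d["scores"]), sum(d["scores"]) / len(d["scores"])) for tid, d in grouped.items()})
# -> {'t1': (2, 1.0), 't2': (1, 0.0)}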
