Skip to content

Commit 466ec59

Browse files
author
Magdalena Kotynia
committed
feat: added merging results summaries across all repeats and all models
1 parent 02b459b commit 466ec59

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed

src/rai_bench/rai_bench/test_models.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,32 @@
1111
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# # See the License for the specific language governing permissions and
1313
# # limitations under the License.
14+
import csv
1415
import uuid
1516
from abc import abstractmethod
1617
from datetime import datetime
1718
from pathlib import Path
1819
from typing import Any, Dict, List, Literal
1920

21+
import numpy as np
2022
from git import Optional
2123
from langchain.chat_models.base import BaseChatModel
2224
from pydantic import BaseModel
2325

2426
import rai_bench.manipulation_o3de as manipulation_o3de
2527
import rai_bench.tool_calling_agent as tool_calling_agent
2628
import rai_bench.vlm_benchmark as vlm_benchmark
29+
from rai_bench.base_benchmark import RunSummary
30+
from rai_bench.results_processing.data_loading import SUMMARY_FILE_NAME
2731
from rai_bench.utils import (
2832
define_benchmark_logger,
2933
get_llm_for_benchmark,
3034
get_llm_model_name,
3135
)
3236

37+
REPEATS_SUMMARY_FILE_NAME = "repeats_summary.csv"
38+
BENCHMARK_SUMMARY = "benchmark_summary.csv"
39+
3340

3441
class BenchmarkConfig(BaseModel):
3542
repeats: int = 1
@@ -97,6 +104,98 @@ def name(self) -> str:
97104
return "vlm"
98105

99106

107+
def merge_model_repeats_summary(
    bench_name: str, model_name: str, run_dir: Path
) -> None:
    """Merge summary results across all repeats for a single model.

    Scans ``run_dir/bench_name/model_name`` for numeric repeat
    subdirectories, loads each repeat's summary CSV, averages the
    success rate and time across repeats, and writes a single merged
    summary CSV into the model directory. Silently returns when the
    model directory or all summaries are missing (best-effort merge).

    Parameters
    ----------
    bench_name : str
        Name of the benchmark
    model_name : str
        Name of the model
    run_dir : Path
        Directory containing the benchmark run results
    """
    model_dir = run_dir / bench_name / model_name
    if not model_dir.exists():
        return

    # TODO (mkotynia): create new BenchSummary model with added std of success rate and time across repeats
    summaries: List[RunSummary] = []
    for repeat_dir in model_dir.iterdir():
        # Repeat directories are named by their (integer) repeat index.
        if repeat_dir.is_dir() and repeat_dir.name.isdigit():
            summary_file = repeat_dir / SUMMARY_FILE_NAME
            if summary_file.exists():
                with open(summary_file, "r") as f:
                    reader = csv.DictReader(f)
                    for row in reader:
                        summaries.append(RunSummary.model_validate(row))

    if not summaries:
        return

    avg_success_rate = np.mean([s.success_rate for s in summaries])
    avg_time = np.mean([s.avg_time for s in summaries])
    # Builtin min returns a plain Python int; np.min would yield a NumPy
    # scalar (np.int64), which is not an `int` subclass and can trip up
    # pydantic validation / serialization of the merged model.
    total_tasks = min(
        s.total_tasks for s in summaries
    )  # NOTE (mkotynia) get the minimum total tasks across repeats. If benchmark breaks for some repeat, it will be noticed in such case

    merged_summary = RunSummary(
        model_name=model_name,
        success_rate=round(float(avg_success_rate), 2),
        avg_time=round(float(avg_time), 3),
        total_tasks=total_tasks,
    )

    merged_file = model_dir / REPEATS_SUMMARY_FILE_NAME
    with open(merged_file, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=RunSummary.model_fields.keys())
        writer.writeheader()
        writer.writerow(merged_summary.model_dump())
157+
158+
159+
def merge_benchmark_summary(
    bench_name: str, run_dir: Path, model_names: List[str]
) -> None:
    """Merge summary results across all models for a single benchmark.

    Collects each model's merged per-repeat summary CSV (when present)
    from ``run_dir/bench_name/<model>`` and writes them into one
    benchmark-wide summary CSV in the benchmark directory. Returns
    without writing anything when the benchmark directory is missing or
    no model summaries were found.

    Parameters
    ----------
    bench_name : str
        Name of the benchmark
    run_dir : Path
        Directory containing the benchmark run results
    model_names : List[str]
        List of model names to include in the summary
    """
    bench_dir = run_dir / bench_name
    if not bench_dir.exists():
        return

    collected: List[RunSummary] = []
    for name in model_names:
        candidate = bench_dir / name / REPEATS_SUMMARY_FILE_NAME
        # A model whose run failed may never have produced a merged file.
        if not candidate.exists():
            continue
        with open(candidate, "r") as f:
            for record in csv.DictReader(f):
                collected.append(RunSummary.model_validate(record))

    if not collected:
        return

    out_path = bench_dir / BENCHMARK_SUMMARY
    with open(out_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=RunSummary.model_fields.keys())
        writer.writeheader()
        for entry in collected:
            writer.writerow(entry.model_dump())
197+
198+
100199
def test_dual_agents(
101200
multimodal_llms: List[BaseChatModel],
102201
tool_calling_models: List[BaseChatModel],
@@ -183,6 +282,7 @@ def test_models(
183282
# for each bench configuration separate run folder
184283
now = datetime.now()
185284
run_name = f"run_{now.strftime('%Y-%m-%d_%H-%M-%S')}"
285+
run_dir = Path(out_dir) / run_name
186286
for i, model_name in enumerate(model_names):
187287
for u in range(bench_conf.repeats):
188288
curr_out_dir = (
@@ -240,8 +340,20 @@ def test_models(
240340
tasks=vlm_tasks,
241341
bench_logger=bench_logger,
242342
)
343+
243344
except Exception as e:
244345
bench_logger.critical(f"BENCHMARK RUN FAILED: {e}")
245346
bench_logger.critical(
246347
f"{bench_conf.name} benchmark for {model_name}, vendor: {vendors[i]}, execution number: {u + 1}"
247348
)
349+
# TODO (mkotynia): resolve unbound bench_logger
350+
bench_logger.info(f"Merging summaries for benchmark: {bench_conf.name}")
351+
352+
for model_name in model_names:
353+
merge_model_repeats_summary(bench_conf.name, model_name, run_dir)
354+
355+
merge_benchmark_summary(bench_conf.name, run_dir, model_names)
356+
357+
bench_logger.info(
358+
f"Summary merging completed for benchmark: {bench_conf.name}"
359+
)

0 commit comments

Comments
 (0)