 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import csv
 import uuid
 from abc import abstractmethod
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, List, Literal

+import numpy as np
 from git import Optional
 from langchain.chat_models.base import BaseChatModel
 from pydantic import BaseModel

 import rai_bench.manipulation_o3de as manipulation_o3de
 import rai_bench.tool_calling_agent as tool_calling_agent
 import rai_bench.vlm_benchmark as vlm_benchmark
+from rai_bench.base_benchmark import RunSummary
+from rai_bench.results_processing.data_loading import SUMMARY_FILE_NAME
 from rai_bench.utils import (
     define_benchmark_logger,
     get_llm_for_benchmark,
     get_llm_model_name,
 )

+REPEATS_SUMMARY_FILE_NAME = "repeats_summary.csv"
+BENCHMARK_SUMMARY = "benchmark_summary.csv"
+

 class BenchmarkConfig(BaseModel):
     repeats: int = 1
@@ -97,6 +104,98 @@ def name(self) -> str:
         return "vlm"


+def merge_model_repeats_summary(
+    bench_name: str, model_name: str, run_dir: Path
+) -> None:
+    """Merge summary results across all repeats for a single model.
+
+    Parameters
+    ----------
+    bench_name : str
+        Name of the benchmark
+    model_name : str
+        Name of the model
+    run_dir : Path
+        Directory containing the benchmark run results
+    """
+    model_dir = run_dir / bench_name / model_name
+    if not model_dir.exists():
+        return
+
+    # TODO (mkotynia): create new BenchSummary model with added std of success rate and time across repeats
+    summaries: List[RunSummary] = []
+    for repeat_dir in model_dir.iterdir():
+        if repeat_dir.is_dir() and repeat_dir.name.isdigit():
+            summary_file = repeat_dir / SUMMARY_FILE_NAME
+            if summary_file.exists():
+                with open(summary_file, "r") as f:
+                    reader = csv.DictReader(f)
+                    for row in reader:
+                        summaries.append(RunSummary.model_validate(row))
+
+    if not summaries:
+        return
+
+    avg_success_rate = np.mean([s.success_rate for s in summaries])
+    avg_time = np.mean([s.avg_time for s in summaries])
+    total_tasks = np.min(
+        [s.total_tasks for s in summaries]
+    )  # NOTE (mkotynia) get the minimum total tasks across repeats. If benchmark breaks for some repeat, it will be noticed in such case
+
+    merged_summary = RunSummary(
+        model_name=model_name,
+        success_rate=round(float(avg_success_rate), 2),
+        avg_time=round(float(avg_time), 3),
+        total_tasks=total_tasks,
+    )
+
+    merged_file = model_dir / REPEATS_SUMMARY_FILE_NAME
+    with open(merged_file, "w", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=RunSummary.model_fields.keys())
+        writer.writeheader()
+        writer.writerow(merged_summary.model_dump())
+
+
+def merge_benchmark_summary(
+    bench_name: str, run_dir: Path, model_names: List[str]
+) -> None:
+    """Merge summary results across all models for a single benchmark.
+
+    Parameters
+    ----------
+    bench_name : str
+        Name of the benchmark
+    run_dir : Path
+        Directory containing the benchmark run results
+    model_names : List[str]
+        List of model names to include in the summary
+    """
+    bench_dir = run_dir / bench_name
+    if not bench_dir.exists():
+        return
+
+    all_summaries: List[RunSummary] = []
+    for model_name in model_names:
+        model_dir = bench_dir / model_name
+        merged_file = model_dir / REPEATS_SUMMARY_FILE_NAME
+
+        if merged_file.exists():
+            with open(merged_file, "r") as f:
+                reader = csv.DictReader(f)
+                for row in reader:
+                    all_summaries.append(RunSummary.model_validate(row))
+
+    if not all_summaries:
+        return
+
+    benchmark_summary_file = bench_dir / BENCHMARK_SUMMARY
+    with open(benchmark_summary_file, "w", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=RunSummary.model_fields.keys())
+        writer.writeheader()
+        for summary in all_summaries:
+            writer.writerow(summary.model_dump())
+
+
 def test_dual_agents(
     multimodal_llms: List[BaseChatModel],
     tool_calling_models: List[BaseChatModel],
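
Note on usage: the two helpers above first average each model's per-repeat summaries (read from SUMMARY_FILE_NAME inside every numbered repeat directory) into <model_dir>/repeats_summary.csv, then gather those per-model rows into <bench_dir>/benchmark_summary.csv. A minimal sketch of calling them on an existing run directory follows; the directory name, benchmark name and model names are illustrative only, not taken from this PR:

    from pathlib import Path

    # Hypothetical run directory produced by test_models (real runs are named
    # run_<timestamp>) and illustrative benchmark/model names.
    run_dir = Path("out") / "run_2025-01-01_12-00-00"
    bench_name = "tool_calling_agent"
    model_names = ["model_a", "model_b"]

    # Step 1: per model, average success rate and time across repeats and write
    # <run_dir>/<bench_name>/<model_name>/repeats_summary.csv.
    for model_name in model_names:
        merge_model_repeats_summary(bench_name, model_name, run_dir)

    # Step 2: collect the per-model rows into
    # <run_dir>/<bench_name>/benchmark_summary.csv.
    merge_benchmark_summary(bench_name, run_dir, model_names)
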
@@ -183,6 +282,7 @@ def test_models(
         # for each bench configuration seperate run folder
         now = datetime.now()
         run_name = f"run_{now.strftime('%Y-%m-%d_%H-%M-%S')}"
+        run_dir = Path(out_dir) / run_name
         for i, model_name in enumerate(model_names):
             for u in range(bench_conf.repeats):
                 curr_out_dir = (
@@ -240,8 +340,20 @@ def test_models(
                             tasks=vlm_tasks,
                             bench_logger=bench_logger,
                         )
+
             except Exception as e:
                 bench_logger.critical(f"BENCHMARK RUN FAILED: {e}")
                 bench_logger.critical(
                     f"{bench_conf.name} benchmark for {model_name}, vendor: {vendors[i]}, execution number: {u + 1}"
                 )
+        # TODO (mkotynia): resolve unbound bench_logger
+        bench_logger.info(f"Merging summaries for benchmark: {bench_conf.name}")
+
+        for model_name in model_names:
+            merge_model_repeats_summary(bench_conf.name, model_name, run_dir)
+
+        merge_benchmark_summary(bench_conf.name, run_dir, model_names)
+
+        bench_logger.info(
+            f"Summary merging completed for benchmark: {bench_conf.name}"
+        )
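
The TODO in the hunk above points at a real edge case: bench_logger is only bound inside the per-model/per-repeat loops, so the merging log calls could raise NameError if those loops never execute. One hedged way to resolve it, sketched under the assumption that define_benchmark_logger from rai_bench.utils accepts an output directory (as its per-repeat use suggests), would be to bind a benchmark-level logger before the loops:

    # Sketch only, not part of this diff: bind a logger for the whole benchmark
    # configuration so the summary-merging step can always log.
    bench_logger = define_benchmark_logger(out_dir=run_dir)
    bench_logger.info(f"Merging summaries for benchmark: {bench_conf.name}")

    for model_name in model_names:
        merge_model_repeats_summary(bench_conf.name, model_name, run_dir)

    merge_benchmark_summary(bench_conf.name, run_dir, model_names)
    bench_logger.info(f"Summary merging completed for benchmark: {bench_conf.name}")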