@@ -26,15 +26,19 @@
 import rai_bench.manipulation_o3de as manipulation_o3de
 import rai_bench.tool_calling_agent as tool_calling_agent
 import rai_bench.vlm_benchmark as vlm_benchmark
-from rai_bench.base_benchmark import ModelSummary, RunSummary
-from rai_bench.results_processing.data_loading import SUMMARY_FILE_NAME
+from rai_bench.base_benchmark import ModelSummary, RunSummary, TasksSummary
+from rai_bench.results_processing.data_loading import (
+    DETAILED_FILE_NAME,
+    SUMMARY_FILE_NAME,
+)
 from rai_bench.utils import (
     define_benchmark_logger,
     get_llm_for_benchmark,
     get_llm_model_name,
 )
 
 REPEATS_SUMMARY_FILE_NAME = "repeats_summary.csv"
+TASKS_SUMMARY_FILE_NAME = "tasks_summary.csv"
 BENCHMARK_SUMMARY = "benchmark_summary.csv"
 
 
@@ -151,7 +155,7 @@ def merge_model_repeats_summary(
 
     merged_file = model_dir / REPEATS_SUMMARY_FILE_NAME
     with open(merged_file, "w", newline="") as f:
-        writer = csv.DictWriter(f, fieldnames=RunSummary.model_fields.keys())
+        writer = csv.DictWriter(f, fieldnames=ModelSummary.model_fields.keys())
         writer.writeheader()
         writer.writerow(merged_summary.model_dump())
 
@@ -174,7 +178,7 @@ def merge_benchmark_summary(
     if not bench_dir.exists():
         return
 
-    all_summaries: List[RunSummary] = []
+    all_summaries: List[ModelSummary] = []
     for model_name in model_names:
         model_dir = bench_dir / model_name
         merged_file = model_dir / REPEATS_SUMMARY_FILE_NAME
@@ -183,19 +187,89 @@
         with open(merged_file, "r") as f:
             reader = csv.DictReader(f)
             for row in reader:
-                all_summaries.append(RunSummary.model_validate(row))
+                all_summaries.append(ModelSummary.model_validate(row))
 
     if not all_summaries:
         return
 
     benchmark_summary_file = bench_dir / BENCHMARK_SUMMARY
     with open(benchmark_summary_file, "w", newline="") as f:
-        writer = csv.DictWriter(f, fieldnames=RunSummary.model_fields.keys())
+        writer = csv.DictWriter(f, fieldnames=ModelSummary.model_fields.keys())
         writer.writeheader()
         for summary in all_summaries:
            writer.writerow(summary.model_dump())
 
 
+def merge_tasks_summary(bench_name: str, model_name: str, run_dir: Path) -> None:
+    """Merge task results across all repeats for a single model, aggregating by task.
+
+    Parameters
+    ----------
+    bench_name : str
+        Name of the benchmark
+    model_name : str
+        Name of the model
+    run_dir : Path
+        Directory containing the benchmark run results
+    """
+    model_dir = run_dir / bench_name / model_name
+    if not model_dir.exists():
+        return
+
+    # Collect all task results from all repeats
+    task_data_by_prompt: Dict[str, Dict[str, List[float]]] = {}
+
+    for repeat_dir in model_dir.iterdir():
+        if repeat_dir.is_dir() and repeat_dir.name.isdigit():
+            results_file = repeat_dir / DETAILED_FILE_NAME
+            if results_file.exists():
+                # Read detailed results from this repeat
+                with open(results_file, "r") as f:
+                    reader = csv.DictReader(f)
+                    for row in reader:
+                        task_prompt = row["task_prompt"]
+                        score = float(row["score"])
+                        total_time = float(row["total_time"])
+
+                        if task_prompt not in task_data_by_prompt:
+                            task_data_by_prompt[task_prompt] = {
+                                "scores": [],
+                                "times": [],
+                            }
+
+                        task_data_by_prompt[task_prompt]["scores"].append(score)
+                        task_data_by_prompt[task_prompt]["times"].append(total_time)
+
+    if not task_data_by_prompt:
+        return
+
+    # Calculate statistics for each task
+    task_summaries: List[TasksSummary] = []
+    for task_prompt, data in task_data_by_prompt.items():
+        scores = np.array(data["scores"])
+        times = np.array(data["times"])
+
+        task_summary = TasksSummary(
+            model_name=model_name,
+            task_prompt=task_prompt,
+            avg_success_rate=round(float(scores.mean()), 3),
+            std_success_rate=round(float(scores.std()), 3),
+            avg_time=round(float(times.mean()), 3),
+            std_time=round(float(times.std()), 3),
+            repeats=len(scores),  # TODO (mkotynia) (extract repeats in another way)
+        )
+        task_summaries.append(task_summary)
+
+    # Save task summaries to CSV
+    tasks_summary_file = model_dir / TASKS_SUMMARY_FILE_NAME
+    with open(tasks_summary_file, "w", newline="") as f:
+        if task_summaries:
+            writer = csv.DictWriter(f, fieldnames=TasksSummary.model_fields.keys())
+            writer.writeheader()
+            for task_summary in task_summaries:
+                writer.writerow(task_summary.model_dump())
+
+
 def test_dual_agents(
     multimodal_llms: List[BaseChatModel],
     tool_calling_models: List[BaseChatModel],
@@ -351,6 +425,7 @@ def test_models(
 
         for model_name in model_names:
             merge_model_repeats_summary(bench_conf.name, model_name, run_dir)
+            merge_tasks_summary(bench_conf.name, model_name, run_dir)
 
         merge_benchmark_summary(bench_conf.name, run_dir, model_names)
 