Skip to content

Commit b4124d0

Browse files
refactor: created separate structure for storing summary for the model across all tasks and repeats
1 parent f2150e8 commit b4124d0

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

src/rai_bench/rai_bench/base_benchmark.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,44 @@ class RunSummary(BaseModel):
3535
total_tasks: int = Field(..., description="Total number of executed tasks.")
3636

3737

38+
class ModelSummary(BaseModel):
39+
model_name: str = Field(..., description="Name of the LLM.")
40+
avg_success_rate: float = Field(
41+
...,
42+
description="Percentage of successfully completed tasks across all repeats.",
43+
)
44+
avg_total_tasks: float = Field(
45+
..., description="Average number of tasks executed through all repeats."
46+
)
47+
avg_time: float = Field(
48+
..., description="Average time taken across all tasks and repeats."
49+
)
50+
51+
repeats: int = Field(
52+
..., description="Total number of repeats for the model for each task."
53+
)
54+
55+
56+
class TasksSummary(BaseModel):
57+
model_name: str = Field(..., description="Name of the LLM.")
58+
avg_success_rate: float = Field(
59+
..., description="Average result for task across all repeats."
60+
)
61+
std_success_rate: float = Field(
62+
..., description="Standard deviation of the success rate across all repeats."
63+
)
64+
avg_time: float = Field(
65+
..., description="Average time taken across all repeats for one task."
66+
)
67+
std_time: float = Field(
68+
...,
69+
description="Standard deviation of the time taken across all repeats for one task.",
70+
)
71+
total_tasks: int = Field(
72+
..., description="Total number of executed tasks across all repeats per task."
73+
)
74+
75+
3876
class TimeoutException(Exception):
3977
pass
4078

src/rai_bench/rai_bench/test_models.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import rai_bench.manipulation_o3de as manipulation_o3de
2727
import rai_bench.tool_calling_agent as tool_calling_agent
2828
import rai_bench.vlm_benchmark as vlm_benchmark
29-
from rai_bench.base_benchmark import RunSummary
29+
from rai_bench.base_benchmark import ModelSummary, RunSummary
3030
from rai_bench.results_processing.data_loading import SUMMARY_FILE_NAME
3131
from rai_bench.utils import (
3232
define_benchmark_logger,
@@ -138,15 +138,15 @@ def merge_model_repeats_summary(
138138

139139
avg_success_rate = np.mean([s.success_rate for s in summaries])
140140
avg_time = np.mean([s.avg_time for s in summaries])
141-
total_tasks = np.min(
142-
[s.total_tasks for s in summaries]
143-
) # NOTE (mkotynia) get the minimum total tasks across repeats. If benchmark breaks for some repeat, it will be noticed in such case
144141

145-
merged_summary = RunSummary(
142+
total_tasks = np.mean([s.total_tasks for s in summaries])
143+
144+
merged_summary = ModelSummary(
146145
model_name=model_name,
147-
success_rate=round(float(avg_success_rate), 2),
146+
avg_success_rate=round(float(avg_success_rate), 2),
148147
avg_time=round(float(avg_time), 3),
149-
total_tasks=total_tasks,
148+
avg_total_tasks=round(float(total_tasks), 3),
149+
repeats=len(summaries),
150150
)
151151

152152
merged_file = model_dir / REPEATS_SUMMARY_FILE_NAME

0 commit comments

Comments
 (0)