Commit 30c6b6a

refactor: added task id, refactor of storing task_input variables
1 parent 01179b6 commit 30c6b6a

File tree

5 files changed: +119, -22 lines

src/rai_bench/rai_bench/test_models.py

Lines changed: 81 additions & 6 deletions
@@ -26,15 +26,19 @@
 import rai_bench.manipulation_o3de as manipulation_o3de
 import rai_bench.tool_calling_agent as tool_calling_agent
 import rai_bench.vlm_benchmark as vlm_benchmark
-from rai_bench.base_benchmark import ModelSummary, RunSummary
-from rai_bench.results_processing.data_loading import SUMMARY_FILE_NAME
+from rai_bench.base_benchmark import ModelSummary, RunSummary, TasksSummary
+from rai_bench.results_processing.data_loading import (
+    DETAILED_FILE_NAME,
+    SUMMARY_FILE_NAME,
+)
 from rai_bench.utils import (
     define_benchmark_logger,
     get_llm_for_benchmark,
     get_llm_model_name,
 )
 
 REPEATS_SUMMARY_FILE_NAME = "repeats_summary.csv"
+TASKS_SUMMARY_FILE_NAME = "tasks_summary.csv"
 BENCHMARK_SUMMARY = "benchmark_summary.csv"
 
 
@@ -151,7 +155,7 @@ def merge_model_repeats_summary(
 
     merged_file = model_dir / REPEATS_SUMMARY_FILE_NAME
     with open(merged_file, "w", newline="") as f:
-        writer = csv.DictWriter(f, fieldnames=RunSummary.model_fields.keys())
+        writer = csv.DictWriter(f, fieldnames=ModelSummary.model_fields.keys())
         writer.writeheader()
         writer.writerow(merged_summary.model_dump())
 
@@ -174,7 +178,7 @@ def merge_benchmark_summary(
     if not bench_dir.exists():
         return
 
-    all_summaries: List[RunSummary] = []
+    all_summaries: List[ModelSummary] = []
     for model_name in model_names:
         model_dir = bench_dir / model_name
         merged_file = model_dir / REPEATS_SUMMARY_FILE_NAME
@@ -183,19 +187,89 @@ def merge_benchmark_summary(
         with open(merged_file, "r") as f:
             reader = csv.DictReader(f)
             for row in reader:
-                all_summaries.append(RunSummary.model_validate(row))
+                all_summaries.append(ModelSummary.model_validate(row))
 
     if not all_summaries:
         return
 
     benchmark_summary_file = bench_dir / BENCHMARK_SUMMARY
     with open(benchmark_summary_file, "w", newline="") as f:
-        writer = csv.DictWriter(f, fieldnames=RunSummary.model_fields.keys())
+        writer = csv.DictWriter(f, fieldnames=ModelSummary.model_fields.keys())
         writer.writeheader()
         for summary in all_summaries:
             writer.writerow(summary.model_dump())
 
 
+def merge_tasks_summary(bench_name: str, model_name: str, run_dir: Path) -> None:
+    """Merge task results across all repeats for a single model, aggregating by task.
+
+    Parameters
+    ----------
+    bench_name : str
+        Name of the benchmark
+    model_name : str
+        Name of the model
+    run_dir : Path
+        Directory containing the benchmark run results
+    """
+    model_dir = run_dir / bench_name / model_name
+    if not model_dir.exists():
+        return
+
+    # Collect all task results from all repeats
+    task_data_by_prompt: Dict[str, Dict[str, List[float]]] = {}
+
+    for repeat_dir in model_dir.iterdir():
+        if repeat_dir.is_dir() and repeat_dir.name.isdigit():
+            results_file = repeat_dir / DETAILED_FILE_NAME
+            if results_file.exists():
+                # Read detailed results from this repeat
+                with open(results_file, "r") as f:
+                    reader = csv.DictReader(f)
+                    for row in reader:
+                        task_prompt = row["task_prompt"]
+                        score = float(row["score"])
+                        total_time = float(row["total_time"])
+
+                        if task_prompt not in task_data_by_prompt:
+                            task_data_by_prompt[task_prompt] = {
+                                "scores": [],
+                                "times": [],
+                            }
+
+                        task_data_by_prompt[task_prompt]["scores"].append(score)
+                        task_data_by_prompt[task_prompt]["times"].append(total_time)
+
+    if not task_data_by_prompt:
+        return
+
+    # Calculate statistics for each task
+    task_summaries: List[TasksSummary] = []
+    for task_prompt, data in task_data_by_prompt.items():
+        scores = np.array(data["scores"])
+        times = np.array(data["times"])
+
+        task_summary = TasksSummary(
+            model_name=model_name,
+            task_prompt=task_prompt,
+            avg_success_rate=round(float(scores.mean()), 3),
+            std_success_rate=round(float(scores.std()), 3),
+            avg_time=round(float(times.mean()), 3),
+            std_time=round(float(times.std()), 3),
+            repeats=len(scores),  # TODO (mkotynia) (extract repeats in another way)
+        )
+        task_summaries.append(task_summary)
+
+    # Save task summaries to CSV
+    tasks_summary_file = model_dir / TASKS_SUMMARY_FILE_NAME
+    with open(tasks_summary_file, "w", newline="") as f:
+        if task_summaries:
+            writer = csv.DictWriter(f, fieldnames=TasksSummary.model_fields.keys())
+            writer.writeheader()
+            for task_summary in task_summaries:
+                writer.writerow(task_summary.model_dump())
+
+
 def test_dual_agents(
     multimodal_llms: List[BaseChatModel],
     tool_calling_models: List[BaseChatModel],
@@ -351,6 +425,7 @@ def test_models(
 
         for model_name in model_names:
             merge_model_repeats_summary(bench_conf.name, model_name, run_dir)
+            merge_tasks_summary(bench_conf.name, model_name, run_dir)
 
         merge_benchmark_summary(bench_conf.name, run_dir, model_names)
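
For context, each row that merge_tasks_summary writes reduces one task prompt's repeated runs to rounded mean/std statistics. A minimal standalone sketch of that arithmetic, using made-up scores and times:

import numpy as np

scores = np.array([1.0, 0.0, 1.0])  # hypothetical per-repeat scores for one task
times = np.array([2.1, 1.8, 2.4])   # hypothetical per-repeat times in seconds

row = {
    "avg_success_rate": round(float(scores.mean()), 3),  # 0.667
    "std_success_rate": round(float(scores.std()), 3),   # 0.471 (population std)
    "avg_time": round(float(times.mean()), 3),           # 2.1
    "std_time": round(float(times.std()), 3),            # 0.245
    "repeats": len(scores),                               # 3
}
print(row)  # one row of tasks_summary.csv, minus model_name and task_prompt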

src/rai_bench/rai_bench/vlm_benchmark/benchmark.py

Lines changed: 2 additions & 1 deletion
@@ -133,14 +133,15 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
             score = task.validate(output=structured_output)
         else:
             errors.append(f"Not valid structured output: {type(structured_output)}")
-            score = False
+            score = 0
 
         te = time.perf_counter()
         total_time = te - ts
 
         self.logger.info(f"TASK SCORE: {score}, TOTAL TIME: {total_time:.3f}")
 
         task_result = TaskResult(
+            task_id=task.task_id,
             task_prompt=task.get_prompt(),
             system_prompt=task.get_system_prompt(),
             type=task.type,
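
One plausible reason for changing the fallback score from False to 0 (my reading, not stated in the commit): scores are written to CSV and parsed back with float() in merge_tasks_summary, and a boolean round-trips as the string "False", which float() rejects. A tiny illustration:

print(float("0"))    # 0.0, the numeric fallback parses cleanly from CSV
try:
    float("False")   # a boolean written to CSV comes back as the string "False"
except ValueError as exc:
    print(exc)       # could not convert string to float: 'False'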

src/rai_bench/rai_bench/vlm_benchmark/interfaces.py

Lines changed: 32 additions & 4 deletions
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import hashlib
 import logging
 from abc import ABC, abstractmethod
 from typing import Any, Generic, List, Literal, Optional, TypeVar
 
 from langchain_core.messages import BaseMessage
 from langchain_core.runnables.config import DEFAULT_RECURSION_LIMIT
-from pydantic import BaseModel, ConfigDict, Field, ValidationError
+from pydantic import BaseModel, ConfigDict, Field, ValidationError, computed_field
 
 loggers_type = logging.Logger
 
@@ -47,6 +48,13 @@ class ImageReasoningTaskInput(BaseModel, Generic[AnswerT]):
         ..., description="The expected answer to the question."
     )
 
+    @computed_field
+    @property
+    def task_id(self) -> str:
+        """Unique identifier for the task based on question and image paths."""
+        content = f"{self.question}|{sorted(self.images_paths)}"
+        return hashlib.sha256(content.encode()).hexdigest()
+
 
 class ImageReasoningAnswer(BaseModel, Generic[AnswerT]):
     """Base answer for an image reasoning task."""
@@ -84,6 +92,7 @@ class ImageReasoningTask(ABC, Generic[AnswerT]):
 
     def __init__(
         self,
+        task_input: ImageReasoningTaskInput[AnswerT],
         logger: loggers_type | None = None,
     ) -> None:
         """
@@ -101,9 +110,28 @@ def __init__(
             self.logger = logger
         else:
             self.logger = logging.getLogger(__name__)
-        self.question: str
-        self.images_paths: List[str]
-        # TODO move here task input
+
+        self._task_input = task_input
+
+    @property
+    def question(self) -> str:
+        """The question to be answered."""
+        return self._task_input.question
+
+    @property
+    def images_paths(self) -> List[str]:
+        """List of image file paths."""
+        return self._task_input.images_paths
+
+    @property
+    def expected_answer(self) -> AnswerT:
+        """The expected answer to the question."""
+        return self._task_input.expected_answer
+
+    @property
+    def task_id(self) -> str:
+        """Unique identifier for the task."""
+        return self._task_input.task_id
 
     def set_logger(self, logger: loggers_type):
         self.logger = logger
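
The computed task_id is simply the SHA-256 digest of the question joined with the sorted image paths, so identical inputs always map to the same identifier and image order does not matter. A minimal sketch with made-up values:

import hashlib

question = "Is there a robot arm in the image?"           # hypothetical
images_paths = ["imgs/scene_2.png", "imgs/scene_1.png"]   # hypothetical

content = f"{question}|{sorted(images_paths)}"             # same scheme as task_id above
print(hashlib.sha256(content.encode()).hexdigest())        # stable 64-character hex id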

src/rai_bench/rai_bench/vlm_benchmark/results_tracking.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 
 
 class TaskResult(BaseModel):
+    task_id: str = Field(..., description="Unique identifier for the task object.")
     task_prompt: str = Field(..., description="The task prompt.")
     system_prompt: str = Field(..., description="The system prompt.")
     complexity: str = Field(..., description="Complexity of the task.")

src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py

Lines changed: 3 additions & 11 deletions
@@ -66,11 +66,9 @@ def __init__(
         logger: loggers_type | None = None,
     ) -> None:
         super().__init__(
+            task_input=task_input,
             logger=logger,
         )
-        self.question = task_input.question
-        self.images_paths = task_input.images_paths
-        self.expected_answer = task_input.expected_answer
 
     @property
     def structured_output(self) -> type[BoolAnswerWithJustification]:
@@ -101,10 +99,7 @@ def __init__(
         task_input: QuantityImageTaskInput,
         logger: loggers_type | None = None,
     ) -> None:
-        super().__init__(logger=logger)
-        self.question = task_input.question
-        self.images_paths = task_input.images_paths
-        self.expected_answer = task_input.expected_answer
+        super().__init__(task_input=task_input, logger=logger)
 
     @property
     def type(self) -> str:
@@ -135,11 +130,8 @@ def __init__(
         task_input: MultipleChoiceImageTaskInput,
         logger: loggers_type | None = None,
     ) -> None:
-        super().__init__(logger=logger)
-        self.question = task_input.question
-        self.images_paths = task_input.images_paths
+        super().__init__(task_input=task_input, logger=logger)
         self.options = task_input.options
-        self.expected_answer = task_input.expected_answer
 
     @property
     def type(self) -> str:
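
With the refactor, a concrete task only forwards task_input to the base class and keeps subclass-specific fields (here options); question, images_paths, expected_answer, and task_id are all served by the base-class properties. A rough usage sketch: the field values are invented and the concrete task class name is assumed, so the exact constructor may differ.

task_input = MultipleChoiceImageTaskInput(
    question="Which object is closest to the robot?",    # hypothetical
    images_paths=["imgs/scene_0.png"],                    # hypothetical
    options=["cube", "ball", "cone"],                     # hypothetical
    expected_answer="cube",                               # hypothetical
)

task = MultipleChoiceImageTask(task_input=task_input)    # class name assumed
assert task.question == task_input.question              # read through base property
assert task.task_id == task_input.task_id                # same SHA-256 id as the input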
