Skip to content

Commit 166e697

Browse files
author
Magdalena Kotynia
committed
feat: interfaces for quantity and multiple choice tasks
1 parent 04c7900 commit 166e697

File tree

2 files changed

+134
-4
lines changed

2 files changed

+134
-4
lines changed

src/rai_bench/rai_bench/vlm_benchmark/interfaces.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def get_prompt(self) -> str:
118118
pass
119119

120120
@abstractmethod
121-
def validate(self, output: BaseModelT) -> bool:
121+
def validate(self, output: BaseModelT) -> float:
122122
"""Validate result of the task."""
123123
pass
124124

src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py

Lines changed: 133 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
import logging
17-
from typing import List
17+
from typing import List, Type
1818

1919
from pydantic import BaseModel, Field
2020
from rai.messages import preprocess_image
@@ -31,6 +31,20 @@ class BoolAnswerWithJustification(BaseModel):
3131
justification: str
3232

3333

34+
class QuantityAnswerWithJustification(BaseModel):
35+
"""A quantity answer to the user question along with justification for the answer."""
36+
37+
answer: int
38+
justification: str
39+
40+
41+
class MultipleChoiceAnswerWithJustification(BaseModel):
42+
"""A multiple choice answer to the user question along with justification for the answer."""
43+
44+
answer: List[str]
45+
justification: str
46+
47+
3448
class BoolImageTaskInput(BaseModel):
3549
question: str = Field(..., description="The question to be answered.")
3650
images_paths: List[str] = Field(
@@ -42,6 +56,38 @@ class BoolImageTaskInput(BaseModel):
4256
)
4357

4458

59+
class QuantityImageTaskInput(BaseModel):
60+
"""Input for a task that requires counting objects in an image."""
61+
62+
question: str = Field(..., description="The question to be answered.")
63+
images_paths: List[str] = Field(
64+
...,
65+
description="List of image file paths to be used for answering the question.",
66+
)
67+
expected_answer: int = Field(
68+
...,
69+
description="The expected number of objects to be counted in the image.",
70+
)
71+
72+
73+
class MultipleChoiceImageTaskInput(BaseModel):
74+
"""Input for a task that requires selecting one or more answers from a list of options."""
75+
76+
question: str = Field(..., description="The question to be answered.")
77+
images_paths: List[str] = Field(
78+
...,
79+
description="List of image file paths to be used for answering the question.",
80+
)
81+
options: List[str] = Field(
82+
...,
83+
description="List of possible answers to the question.",
84+
)
85+
expected_answer: List[str] = Field(
86+
...,
87+
description="The expected answer to the question being a list of strings chosen from the options.",
88+
)
89+
90+
4591
class BoolImageTask(ImageReasoningTask[BoolAnswerWithJustification]):
4692
complexity = "easy"
4793

@@ -72,5 +118,89 @@ def get_images(self):
72118
images = [preprocess_image(image_path) for image_path in self.images_paths]
73119
return images
74120

75-
def validate(self, output: BoolAnswerWithJustification) -> bool:
76-
return output.answer == self.expected_answer
121+
def validate(self, output: BoolAnswerWithJustification) -> float:
122+
return float(output.answer == self.expected_answer)
123+
124+
125+
class QuantityImageTask(ImageReasoningTask[QuantityAnswerWithJustification]):
126+
"""A task that requires counting objects in an image."""
127+
128+
complexity = "medium"
129+
130+
def __init__(
131+
self,
132+
task_input: QuantityImageTaskInput,
133+
logger: loggers_type | None = None,
134+
) -> None:
135+
super().__init__(logger=logger)
136+
self.question = task_input.question
137+
self.images_paths = task_input.images_paths
138+
self.expected_answer = task_input.expected_answer
139+
140+
@property
141+
def type(self) -> str:
142+
return "quantity_response_image_task"
143+
144+
@property
145+
def structured_output(self) -> Type[QuantityAnswerWithJustification]:
146+
return QuantityAnswerWithJustification
147+
148+
def validate(self, output: QuantityAnswerWithJustification) -> float:
149+
return float(output.answer == self.expected_answer)
150+
151+
def get_prompt(self) -> str:
152+
return self.question
153+
154+
def get_images(self):
155+
images = [preprocess_image(image_path) for image_path in self.images_paths]
156+
return images
157+
158+
159+
class MultipleChoiceImageTask(
160+
ImageReasoningTask[MultipleChoiceAnswerWithJustification]
161+
):
162+
"""A task that requires selecting one or more answers from a set of options."""
163+
164+
complexity = "hard"
165+
166+
def __init__(
167+
self,
168+
task_input: MultipleChoiceImageTaskInput,
169+
logger: loggers_type | None = None,
170+
) -> None:
171+
super().__init__(logger=logger)
172+
self.question = task_input.question
173+
self.images_paths = task_input.images_paths
174+
self.options = task_input.options
175+
self.expected_answer = task_input.expected_answer
176+
177+
@property
178+
def type(self) -> str:
179+
return "multiple_choice_response_image_task"
180+
181+
@property
182+
def structured_output(self) -> Type[MultipleChoiceAnswerWithJustification]:
183+
return MultipleChoiceAnswerWithJustification
184+
185+
def validate(self, output: MultipleChoiceAnswerWithJustification) -> float:
186+
answers_processed = set([answer.casefold() for answer in output.answer])
187+
expected_processed = set([answer.casefold() for answer in self.expected_answer])
188+
189+
if not answers_processed.issubset(expected_processed):
190+
return 0.0
191+
192+
correct_count = len(answers_processed.intersection(expected_processed))
193+
total_expected = len(expected_processed)
194+
195+
return float(correct_count / total_expected) if total_expected > 0 else 0.0
196+
197+
def get_prompt(self) -> str:
198+
return (
199+
self.question
200+
+ " Choose one or more answers from the options: "
201+
+ ", ".join(self.options)
202+
)
203+
204+
def get_images(self):
205+
images = [preprocess_image(image_path) for image_path in self.images_paths]
206+
return images

0 commit comments

Comments
 (0)