14
14
15
15
16
16
import logging
17
- from typing import List
17
+ from typing import List , Type
18
18
19
19
from pydantic import BaseModel , Field
20
20
from rai .messages import preprocess_image
@@ -31,6 +31,20 @@ class BoolAnswerWithJustification(BaseModel):
31
31
justification : str
32
32
33
33
34
class QuantityAnswerWithJustification(BaseModel):
    """A quantity answer to the user question along with justification for the answer."""

    # NOTE: the docstring above is kept verbatim on purpose; pydantic copies a
    # model's docstring into the JSON schema it generates for the model.
    # Both fields are required: `Field(...)` declares no default value.
    answer: int = Field(...)  # the integer given as the answer
    justification: str = Field(...)  # free-text reasoning accompanying `answer`
39
+
40
+
41
class MultipleChoiceAnswerWithJustification(BaseModel):
    """A multiple choice answer to the user question along with justification for the answer."""

    # NOTE: the docstring above is kept verbatim on purpose; pydantic copies a
    # model's docstring into the JSON schema it generates for the model.
    # Both fields are required: `Field(...)` declares no default value.
    answer: List[str] = Field(...)  # the selected answers, as a list of strings
    justification: str = Field(...)  # free-text reasoning accompanying `answer`
46
+
47
+
34
48
class BoolImageTaskInput (BaseModel ):
35
49
question : str = Field (..., description = "The question to be answered." )
36
50
images_paths : List [str ] = Field (
@@ -42,6 +56,38 @@ class BoolImageTaskInput(BaseModel):
42
56
)
43
57
44
58
59
class QuantityImageTaskInput(BaseModel):
    """Input for a task that requires counting objects in an image."""

    # Every field is required (`default=...`); the descriptions are part of the
    # model's schema and are reproduced exactly.
    question: str = Field(default=..., description="The question to be answered.")
    images_paths: List[str] = Field(
        default=...,
        description="List of image file paths to be used for answering the question.",
    )
    expected_answer: int = Field(
        default=...,
        description="The expected number of objects to be counted in the image.",
    )
71
+
72
+
73
class MultipleChoiceImageTaskInput(BaseModel):
    """Input for a task that requires selecting one or more answers from a list of options."""

    # Every field is required (`default=...`); the descriptions are part of the
    # model's schema and are reproduced exactly.
    question: str = Field(default=..., description="The question to be answered.")
    images_paths: List[str] = Field(
        default=...,
        description="List of image file paths to be used for answering the question.",
    )
    options: List[str] = Field(
        default=...,
        description="List of possible answers to the question.",
    )
    expected_answer: List[str] = Field(
        default=...,
        description="The expected answer to the question being a list of strings chosen from the options.",
    )
89
+
90
+
45
91
class BoolImageTask (ImageReasoningTask [BoolAnswerWithJustification ]):
46
92
complexity = "easy"
47
93
@@ -72,5 +118,89 @@ def get_images(self):
72
118
images = [preprocess_image (image_path ) for image_path in self .images_paths ]
73
119
return images
74
120
75
- def validate (self , output : BoolAnswerWithJustification ) -> bool :
76
- return output .answer == self .expected_answer
121
+ def validate (self , output : BoolAnswerWithJustification ) -> float :
122
+ return float (output .answer == self .expected_answer )
123
+
124
+
125
class QuantityImageTask(ImageReasoningTask[QuantityAnswerWithJustification]):
    """A task that requires counting objects in an image.

    Built from a ``QuantityImageTaskInput``; the expected reply type is
    ``QuantityAnswerWithJustification``.
    """

    complexity = "medium"

    def __init__(
        self,
        task_input: QuantityImageTaskInput,
        logger: loggers_type | None = None,
    ) -> None:
        """Copy the question, image paths and expected count from ``task_input``.

        Args:
            task_input: Question, image paths and expected integer answer.
            logger: Optional logger forwarded to the base class.
        """
        super().__init__(logger=logger)
        self.expected_answer = task_input.expected_answer
        self.images_paths = task_input.images_paths
        self.question = task_input.question

    @property
    def type(self) -> str:
        """Identifier string of this task kind."""
        return "quantity_response_image_task"

    @property
    def structured_output(self) -> Type[QuantityAnswerWithJustification]:
        """Model class describing the expected answer structure."""
        return QuantityAnswerWithJustification

    def validate(self, output: QuantityAnswerWithJustification) -> float:
        """Score ``output``: 1.0 when its answer equals the expected count, else 0.0."""
        return 1.0 if output.answer == self.expected_answer else 0.0

    def get_prompt(self) -> str:
        """Return the question text unchanged."""
        return self.question

    def get_images(self):
        """Return ``preprocess_image(path)`` for every configured image path, in order."""
        return list(map(preprocess_image, self.images_paths))
157
+
158
+
159
class MultipleChoiceImageTask(
    ImageReasoningTask[MultipleChoiceAnswerWithJustification]
):
    """A task that requires selecting one or more answers from a set of options.

    Built from a ``MultipleChoiceImageTaskInput``; the expected reply type is
    ``MultipleChoiceAnswerWithJustification``.
    """

    complexity = "hard"

    def __init__(
        self,
        task_input: MultipleChoiceImageTaskInput,
        logger: loggers_type | None = None,
    ) -> None:
        """Copy question, image paths, options and expected answers from ``task_input``.

        Args:
            task_input: Question, image paths, selectable options and the
                expected subset of those options.
            logger: Optional logger forwarded to the base class.
        """
        super().__init__(logger=logger)
        self.question = task_input.question
        self.images_paths = task_input.images_paths
        self.options = task_input.options
        self.expected_answer = task_input.expected_answer

    @property
    def type(self) -> str:
        """Identifier string of this task kind."""
        return "multiple_choice_response_image_task"

    @property
    def structured_output(self) -> Type[MultipleChoiceAnswerWithJustification]:
        """Model class describing the expected answer structure."""
        return MultipleChoiceAnswerWithJustification

    def validate(self, output: MultipleChoiceAnswerWithJustification) -> float:
        """Score ``output`` against the expected answers, ignoring letter case.

        Args:
            output: The structured answer to evaluate.

        Returns:
            0.0 when there are no expected answers or when any selected answer
            is not among the expected ones; otherwise the fraction of distinct
            expected answers that were selected (1.0 for a complete match).
        """
        # Set comprehensions de-duplicate and normalise case in one pass.
        selected = {choice.casefold() for choice in output.answer}
        expected = {choice.casefold() for choice in self.expected_answer}

        # No reference answers, or at least one wrong selection: no credit.
        if not expected or not selected <= expected:
            return 0.0

        # `selected` is a subset of `expected` here, so each selected answer is
        # correct; true division already yields a float.
        return len(selected) / len(expected)

    def get_prompt(self) -> str:
        """Return the question followed by the comma-separated list of options."""
        options_text = ", ".join(self.options)
        return f"{self.question} Choose one or more answers from the options: {options_text}"

    def get_images(self):
        """Return ``preprocess_image(path)`` for every configured image path, in order."""
        return [preprocess_image(path) for path in self.images_paths]
0 commit comments