|
16 | 16 | import logging
|
17 | 17 | from typing import List, Type
|
18 | 18 |
|
19 |
| -from pydantic import BaseModel, Field |
| 19 | +from pydantic import Field |
20 | 20 | from rai.messages import preprocess_image
|
21 | 21 |
|
22 |
| -from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask |
| 22 | +from rai_bench.vlm_benchmark.interfaces import ( |
| 23 | + ImageReasoningAnswer, |
| 24 | + ImageReasoningTask, |
| 25 | + ImageReasoningTaskInput, |
| 26 | +) |
23 | 27 |
|
24 | 28 | loggers_type = logging.Logger
|
25 | 29 |
|
26 | 30 |
|
27 |
| -class BoolAnswerWithJustification(BaseModel): |
| 31 | +class BoolAnswerWithJustification(ImageReasoningAnswer[bool]): |
28 | 32 | """A boolean answer to the user question along with justification for the answer."""
|
29 | 33 |
|
30 |
| - answer: bool |
31 |
| - justification: str |
32 | 34 |
|
| 35 | +class QuantityAnswerWithJustification(ImageReasoningAnswer[int]): |
| 36 | + """A quantity answer telling the number of objects to the user question along with justification for the answer.""" |
33 | 37 |
|
34 |
| -class QuantityAnswerWithJustification(BaseModel): |
35 |
| - """A quantity answer to the user question along with justification for the answer.""" |
36 | 38 |
|
37 |
| - answer: int |
38 |
| - justification: str |
39 |
| - |
40 |
| - |
41 |
| -class MultipleChoiceAnswerWithJustification(BaseModel): |
| 39 | +class MultipleChoiceAnswerWithJustification(ImageReasoningAnswer[List[str]]): |
42 | 40 | """A multiple choice answer to the user question along with justification for the answer."""
|
43 | 41 |
|
44 |
| - answer: List[str] |
45 |
| - justification: str |
46 | 42 |
|
47 |
| - |
48 |
| -class BoolImageTaskInput(BaseModel): |
49 |
| - question: str = Field(..., description="The question to be answered.") |
50 |
| - images_paths: List[str] = Field( |
51 |
| - ..., |
52 |
| - description="List of image file paths to be used for answering the question.", |
53 |
| - ) |
54 |
| - expected_answer: bool = Field( |
55 |
| - ..., description="The expected answer to the question." |
56 |
| - ) |
| 43 | +class BoolImageTaskInput(ImageReasoningTaskInput[bool]): |
| 44 | + """Input for a task that requires a boolean answer to a question about an image.""" |
57 | 45 |
|
58 | 46 |
|
59 |
| -class QuantityImageTaskInput(BaseModel): |
| 47 | +class QuantityImageTaskInput(ImageReasoningTaskInput[int]): |
60 | 48 | """Input for a task that requires counting objects in an image."""
|
61 | 49 |
|
62 |
| - question: str = Field(..., description="The question to be answered.") |
63 |
| - images_paths: List[str] = Field( |
64 |
| - ..., |
65 |
| - description="List of image file paths to be used for answering the question.", |
66 |
| - ) |
67 |
| - expected_answer: int = Field( |
68 |
| - ..., |
69 |
| - description="The expected number of objects to be counted in the image.", |
70 |
| - ) |
71 | 50 |
|
72 |
| - |
73 |
| -class MultipleChoiceImageTaskInput(BaseModel): |
| 51 | +class MultipleChoiceImageTaskInput(ImageReasoningTaskInput[List[str]]): |
74 | 52 | """Input for a task that requires selecting one or more answers from a list of options."""
|
75 | 53 |
|
76 |
| - question: str = Field(..., description="The question to be answered.") |
77 |
| - images_paths: List[str] = Field( |
78 |
| - ..., |
79 |
| - description="List of image file paths to be used for answering the question.", |
80 |
| - ) |
81 | 54 | options: List[str] = Field(
|
82 | 55 | ...,
|
83 | 56 | description="List of possible answers to the question.",
|
84 | 57 | )
|
85 |
| - expected_answer: List[str] = Field( |
86 |
| - ..., |
87 |
| - description="The expected answer to the question being a list of strings chosen from the options.", |
88 |
| - ) |
89 | 58 |
|
90 | 59 |
|
91 | 60 | class BoolImageTask(ImageReasoningTask[BoolAnswerWithJustification]):
|
|
0 commit comments