Skip to content

Commit 89c8bbf

Browse files
author
Magdalena Kotynia
committed
refactor: extracted common logic for VLM task inputs and answers into base ImageReasoningTaskInput and ImageReasoningAnswer classes
1 parent 0500a66 commit 89c8bbf

File tree

2 files changed

+38
-46
lines changed

2 files changed

+38
-46
lines changed

src/rai_bench/rai_bench/vlm_benchmark/interfaces.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818

1919
from langchain_core.messages import BaseMessage
2020
from langchain_core.runnables.config import DEFAULT_RECURSION_LIMIT
21-
from pydantic import BaseModel, ConfigDict, ValidationError
21+
from pydantic import BaseModel, ConfigDict, Field, ValidationError
2222

2323
loggers_type = logging.Logger
2424

@@ -55,6 +55,29 @@ class TaskValidationError(Exception):
5555
pass
5656

5757

58+
AnswerT = TypeVar("AnswerT")
59+
60+
61+
class ImageReasoningTaskInput(BaseModel, Generic[AnswerT]):
62+
"""Base input for an image reasoning task."""
63+
64+
question: str = Field(..., description="The question to be answered.")
65+
images_paths: List[str] = Field(
66+
...,
67+
description="List of image file paths to be used for answering the question.",
68+
)
69+
expected_answer: AnswerT = Field(
70+
..., description="The expected answer to the question."
71+
)
72+
73+
74+
class ImageReasoningAnswer(BaseModel, Generic[AnswerT]):
75+
"""Base answer for an image reasoning task."""
76+
77+
answer: AnswerT = Field(..., description="The answer to the question.")
78+
justification: str = Field(..., description="Justification for the answer.")
79+
80+
5881
class ImageReasoningTask(ABC, Generic[BaseModelT]):
5982
complexity: Literal["easy", "medium", "hard"]
6083
recursion_limit: int = DEFAULT_RECURSION_LIMIT

src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py

Lines changed: 14 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -16,76 +16,45 @@
1616
import logging
1717
from typing import List, Type
1818

19-
from pydantic import BaseModel, Field
19+
from pydantic import Field
2020
from rai.messages import preprocess_image
2121

22-
from rai_bench.vlm_benchmark.interfaces import ImageReasoningTask
22+
from rai_bench.vlm_benchmark.interfaces import (
23+
ImageReasoningAnswer,
24+
ImageReasoningTask,
25+
ImageReasoningTaskInput,
26+
)
2327

2428
loggers_type = logging.Logger
2529

2630

27-
class BoolAnswerWithJustification(BaseModel):
31+
class BoolAnswerWithJustification(ImageReasoningAnswer[bool]):
2832
"""A boolean answer to the user question along with justification for the answer."""
2933

30-
answer: bool
31-
justification: str
3234

35+
class QuantityAnswerWithJustification(ImageReasoningAnswer[int]):
36+
"""A quantity answer telling the number of objects to the user question along with justification for the answer."""
3337

34-
class QuantityAnswerWithJustification(BaseModel):
35-
"""A quantity answer to the user question along with justification for the answer."""
3638

37-
answer: int
38-
justification: str
39-
40-
41-
class MultipleChoiceAnswerWithJustification(BaseModel):
39+
class MultipleChoiceAnswerWithJustification(ImageReasoningAnswer[List[str]]):
4240
"""A multiple choice answer to the user question along with justification for the answer."""
4341

44-
answer: List[str]
45-
justification: str
4642

47-
48-
class BoolImageTaskInput(BaseModel):
49-
question: str = Field(..., description="The question to be answered.")
50-
images_paths: List[str] = Field(
51-
...,
52-
description="List of image file paths to be used for answering the question.",
53-
)
54-
expected_answer: bool = Field(
55-
..., description="The expected answer to the question."
56-
)
43+
class BoolImageTaskInput(ImageReasoningTaskInput[bool]):
44+
"""Input for a task that requires a boolean answer to a question about an image."""
5745

5846

59-
class QuantityImageTaskInput(BaseModel):
47+
class QuantityImageTaskInput(ImageReasoningTaskInput[int]):
6048
"""Input for a task that requires counting objects in an image."""
6149

62-
question: str = Field(..., description="The question to be answered.")
63-
images_paths: List[str] = Field(
64-
...,
65-
description="List of image file paths to be used for answering the question.",
66-
)
67-
expected_answer: int = Field(
68-
...,
69-
description="The expected number of objects to be counted in the image.",
70-
)
7150

72-
73-
class MultipleChoiceImageTaskInput(BaseModel):
51+
class MultipleChoiceImageTaskInput(ImageReasoningTaskInput[List[str]]):
7452
"""Input for a task that requires selecting one or more answers from a list of options."""
7553

76-
question: str = Field(..., description="The question to be answered.")
77-
images_paths: List[str] = Field(
78-
...,
79-
description="List of image file paths to be used for answering the question.",
80-
)
8154
options: List[str] = Field(
8255
...,
8356
description="List of possible answers to the question.",
8457
)
85-
expected_answer: List[str] = Field(
86-
...,
87-
description="The expected answer to the question being a list of strings chosen from the options.",
88-
)
8958

9059

9160
class BoolImageTask(ImageReasoningTask[BoolAnswerWithJustification]):

0 commit comments

Comments
 (0)