Skip to content

Commit 1adc3d4

Browse files
author
Magdalena Kotynia
committed
fix: fixed typing after refactor extracting common logic for vlm tasks inputs and answers
1 parent 466ec59 commit 1adc3d4

File tree

2 files changed

+35
-36
lines changed

2 files changed

+35
-36
lines changed

src/rai_bench/rai_bench/vlm_benchmark/interfaces.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import logging
1616
from abc import ABC, abstractmethod
17-
from typing import Generic, List, Literal, Optional, TypeVar
17+
from typing import Any, Generic, List, Literal, Optional, TypeVar
1818

1919
from langchain_core.messages import BaseMessage
2020
from langchain_core.runnables.config import DEFAULT_RECURSION_LIMIT
@@ -28,29 +28,6 @@
2828
IMAGE_REASONING_SYSTEM_PROMPT = "You are a helpful and knowledgeable AI assistant that specializes in interpreting and analyzing visual content. Your task is to answer questions based on the images provided to you. Please response in requested structured output format."
2929

3030

31-
class LangchainRawOutputModel(BaseModel):
32-
"""
33-
A Pydantic model for wrapping Langchain message parsing results from a structured output agent. See documentation for more details:
34-
https://github.com/langchain-ai/langchain/blob/02001212b0a2b37d90451d8493089389ea220cab/libs/core/langchain_core/language_models/chat_models.py#L1430-L1432
35-
36-
37-
Attributes
38-
----------
39-
raw : BaseMessage
40-
The original raw message object from Langchain before parsing.
41-
parsed : BaseModel
42-
The parsed and validated Pydantic model instance derived from the raw message.
43-
parsing_error : Optional[BaseException]
44-
Any exception that occurred during the parsing process, None if parsing
45-
was successful.
46-
"""
47-
48-
model_config = ConfigDict(arbitrary_types_allowed=True)
49-
raw: BaseMessage
50-
parsed: BaseModel
51-
parsing_error: Optional[BaseException]
52-
53-
5431
class TaskValidationError(Exception):
5532
pass
5633

@@ -78,7 +55,30 @@ class ImageReasoningAnswer(BaseModel, Generic[AnswerT]):
7855
justification: str = Field(..., description="Justification for the answer.")
7956

8057

81-
class ImageReasoningTask(ABC, Generic[BaseModelT]):
58+
class LangchainRawOutputModel(BaseModel):
59+
"""
60+
A Pydantic model for wrapping Langchain message parsing results from a structured output agent. See documentation for more details:
61+
https://github.com/langchain-ai/langchain/blob/02001212b0a2b37d90451d8493089389ea220cab/libs/core/langchain_core/language_models/chat_models.py#L1430-L1432
62+
63+
64+
Attributes
65+
----------
66+
raw : BaseMessage
67+
The original raw message object from Langchain before parsing.
68+
parsed : BaseModel
69+
The parsed and validated Pydantic model instance derived from the raw message.
70+
parsing_error : Optional[BaseException]
71+
Any exception that occurred during the parsing process, None if parsing
72+
was successful.
73+
"""
74+
75+
model_config = ConfigDict(arbitrary_types_allowed=True)
76+
raw: BaseMessage
77+
parsed: ImageReasoningAnswer[Any]
78+
parsing_error: Optional[BaseException]
79+
80+
81+
class ImageReasoningTask(ABC, Generic[AnswerT]):
8282
complexity: Literal["easy", "medium", "hard"]
8383
recursion_limit: int = DEFAULT_RECURSION_LIMIT
8484

@@ -103,13 +103,14 @@ def __init__(
103103
self.logger = logging.getLogger(__name__)
104104
self.question: str
105105
self.images_paths: List[str]
106+
# TODO move here task input
106107

107108
def set_logger(self, logger: loggers_type):
108109
self.logger = logger
109110

110111
@property
111112
@abstractmethod
112-
def structured_output(self) -> type[BaseModelT]:
113+
def structured_output(self) -> type[ImageReasoningAnswer[AnswerT]]:
113114
"""Structured output that agent should return."""
114115
pass
115116

@@ -141,7 +142,7 @@ def get_prompt(self) -> str:
141142
pass
142143

143144
@abstractmethod
144-
def validate(self, output: BaseModelT) -> float:
145+
def validate(self, output: ImageReasoningAnswer[AnswerT]) -> float:
145146
"""Validate result of the task."""
146147
pass
147148

@@ -158,7 +159,7 @@ def get_images(self) -> List[str]:
158159

159160
def get_structured_output_from_messages(
160161
self, messages: List[BaseMessage]
161-
) -> BaseModelT | None:
162+
) -> ImageReasoningAnswer[AnswerT] | None:
162163
"""Extract and validate structured output from a list of messages.
163164
164165
Iterates through messages in reverse order, attempting to find the message that is

src/rai_bench/rai_bench/vlm_benchmark/tasks/tasks.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class MultipleChoiceImageTaskInput(ImageReasoningTaskInput[List[str]]):
5757
)
5858

5959

60-
class BoolImageTask(ImageReasoningTask[BoolAnswerWithJustification]):
60+
class BoolImageTask(ImageReasoningTask[bool]):
6161
complexity = "easy"
6262

6363
def __init__(
@@ -87,11 +87,11 @@ def get_images(self):
8787
images = [preprocess_image(image_path) for image_path in self.images_paths]
8888
return images
8989

90-
def validate(self, output: BoolAnswerWithJustification) -> float:
90+
def validate(self, output: ImageReasoningAnswer[bool]) -> float:
9191
return float(output.answer == self.expected_answer)
9292

9393

94-
class QuantityImageTask(ImageReasoningTask[QuantityAnswerWithJustification]):
94+
class QuantityImageTask(ImageReasoningTask[int]):
9595
"""A task that requires counting objects in an image."""
9696

9797
complexity = "medium"
@@ -114,7 +114,7 @@ def type(self) -> str:
114114
def structured_output(self) -> Type[QuantityAnswerWithJustification]:
115115
return QuantityAnswerWithJustification
116116

117-
def validate(self, output: QuantityAnswerWithJustification) -> float:
117+
def validate(self, output: ImageReasoningAnswer[int]) -> float:
118118
return float(output.answer == self.expected_answer)
119119

120120
def get_prompt(self) -> str:
@@ -125,9 +125,7 @@ def get_images(self):
125125
return images
126126

127127

128-
class MultipleChoiceImageTask(
129-
ImageReasoningTask[MultipleChoiceAnswerWithJustification]
130-
):
128+
class MultipleChoiceImageTask(ImageReasoningTask[List[str]]):
131129
"""A task that requires selecting one or more answers from a set of options."""
132130

133131
complexity = "hard"
@@ -151,7 +149,7 @@ def type(self) -> str:
151149
def structured_output(self) -> Type[MultipleChoiceAnswerWithJustification]:
152150
return MultipleChoiceAnswerWithJustification
153151

154-
def validate(self, output: MultipleChoiceAnswerWithJustification) -> float:
152+
def validate(self, output: ImageReasoningAnswer[List[str]]) -> float:
155153
answers_processed = set([answer.casefold() for answer in output.answer])
156154
expected_processed = set([answer.casefold() for answer in self.expected_answer])
157155

0 commit comments

Comments
 (0)