|
12 | 12 | from urllib.parse import urlparse |
13 | 13 |
|
14 | 14 | import requests |
| 15 | +from langchain_core.language_models import BaseLanguageModel |
15 | 16 | from langchain_core.messages import BaseMessage, HumanMessage |
16 | 17 | from langchain_core.prompt_values import PromptValue |
17 | 18 | from PIL import Image |
18 | 19 | from pydantic import BaseModel |
| 20 | +from typing_extensions import TypedDict |
19 | 21 |
|
20 | 22 | from ragas.callbacks import ChainType, new_group |
21 | 23 | from ragas.exceptions import RagasOutputParserException |
22 | | -from ragas.prompt.pydantic_prompt import PydanticPrompt, RagasOutputParser |
| 24 | +from ragas.prompt.pydantic_prompt import ( |
| 25 | + PydanticPrompt, |
| 26 | + RagasOutputParser, |
| 27 | + is_langchain_llm, |
| 28 | +) |
23 | 29 |
|
24 | 30 | if t.TYPE_CHECKING: |
25 | 31 | from langchain_core.callbacks import Callbacks |
26 | 32 |
|
27 | | - from ragas.llms.base import BaseRagasLLM |
| 33 | +from ragas.llms.base import BaseRagasLLM |
28 | 34 |
|
29 | 35 | # type variables for input and output models |
30 | 36 | InputModel = t.TypeVar("InputModel", bound=BaseModel) |
31 | 37 | OutputModel = t.TypeVar("OutputModel", bound=BaseModel) |
32 | 38 |
|
| 39 | + |
| 40 | +# Specific typed dictionaries for message content |
| 41 | +class TextContent(TypedDict): |
| 42 | + type: t.Literal["text"] |
| 43 | + text: str |
| 44 | + |
| 45 | + |
| 46 | +class ImageUrlContent(TypedDict): |
| 47 | + type: t.Literal["image_url"] |
| 48 | + image_url: dict[str, str] |
| 49 | + |
| 50 | + |
| 51 | +MessageContent = t.Union[TextContent, ImageUrlContent] |
| 52 | + |
33 | 53 | logger = logging.getLogger(__name__) |
34 | 54 |
|
35 | 55 | # --- Constants for Security Policy --- |
@@ -101,7 +121,7 @@ def to_prompt_value(self, data: t.Optional[InputModel] = None): |
101 | 121 |
|
102 | 122 | async def generate_multiple( |
103 | 123 | self, |
104 | | - llm: BaseRagasLLM, |
| 124 | + llm: t.Union[BaseRagasLLM, BaseLanguageModel], |
105 | 125 | data: InputModel, |
106 | 126 | n: int = 1, |
107 | 127 | temperature: t.Optional[float] = None, |
@@ -146,26 +166,47 @@ async def generate_multiple( |
146 | 166 | metadata={"type": ChainType.RAGAS_PROMPT}, |
147 | 167 | ) |
148 | 168 | prompt_value = self.to_prompt_value(processed_data) |
149 | | - resp = await llm.generate( |
150 | | - prompt_value, |
151 | | - n=n, |
152 | | - temperature=temperature, |
153 | | - stop=stop, |
154 | | - callbacks=prompt_cb, |
155 | | - ) |
| 169 | + |
| 170 | +        # Handle both LangChain LLMs and Ragas LLMs.
| 171 | +        # LangChain LLMs expose agenerate_prompt() for async generation;
| 172 | +        # Ragas LLMs define generate() itself as an async method.
| 173 | + if is_langchain_llm(llm): |
| 174 | + # This is a LangChain LLM - use agenerate_prompt() |
| 175 | + langchain_llm = t.cast(BaseLanguageModel, llm) |
| 176 | + resp = await langchain_llm.agenerate_prompt( |
| 177 | + [prompt_value], |
| 178 | + stop=stop, |
| 179 | + callbacks=prompt_cb, |
| 180 | + ) |
| 181 | + else: |
| 182 | + # This is a Ragas LLM - use generate() |
| 183 | + ragas_llm = t.cast(BaseRagasLLM, llm) |
| 184 | + resp = await ragas_llm.generate( |
| 185 | + prompt_value, |
| 186 | + n=n, |
| 187 | + temperature=temperature, |
| 188 | + stop=stop, |
| 189 | + callbacks=prompt_cb, |
| 190 | + ) |
156 | 191 |
|
157 | 192 | output_models = [] |
158 | 193 | parser = RagasOutputParser(pydantic_object=self.output_model) # type: ignore |
159 | 194 | for i in range(n): |
160 | 195 | output_string = resp.generations[0][i].text |
161 | 196 | try: |
162 | | - answer = await parser.parse_output_string( |
163 | | - output_string=output_string, |
164 | | - prompt_value=prompt_value, # type: ignore |
165 | | - llm=llm, |
166 | | - callbacks=prompt_cb, |
167 | | - retries_left=retries_left, |
168 | | - ) |
| 197 | +            # The retry parser requires a BaseRagasLLM, so LangChain LLMs
| 198 | + if is_langchain_llm(llm): |
| 199 | + # Skip parsing retry for LangChain LLMs since parser expects BaseRagasLLM |
| 200 | + answer = self.output_model.model_validate_json(output_string) |
| 201 | + else: |
| 202 | + ragas_llm = t.cast(BaseRagasLLM, llm) |
| 203 | + answer = await parser.parse_output_string( |
| 204 | + output_string=output_string, |
| 205 | + prompt_value=prompt_value, # type: ignore |
| 206 | + llm=ragas_llm, |
| 207 | + callbacks=prompt_cb, |
| 208 | + retries_left=retries_left, |
| 209 | + ) |
169 | 210 | processed_output = self.process_output(answer, data) # type: ignore |
170 | 211 | output_models.append(processed_output) |
171 | 212 | except RagasOutputParserException as e: |
@@ -204,7 +245,7 @@ def to_messages(self) -> t.List[BaseMessage]: |
204 | 245 | # Return empty list or handle as appropriate if all items failed processing |
205 | 246 | return [] |
206 | 247 |
|
207 | | - def _securely_process_item(self, item: str) -> t.Optional[t.Dict[str, t.Any]]: |
| 248 | + def _securely_process_item(self, item: str) -> t.Optional[MessageContent]: |
208 | 249 | """ |
209 | 250 | Securely determines if an item is text, a valid image data URI, |
210 | 251 | or a fetchable image URL according to policy. Returns the appropriate |
@@ -258,11 +299,11 @@ def _looks_like_image_path(self, item: str) -> bool: |
258 | 299 | _, ext = os.path.splitext(path_part) |
259 | 300 | return ext.lower() in COMMON_IMAGE_EXTENSIONS |
260 | 301 |
|
261 | | - def _get_text_payload(self, text: str) -> dict: |
| 302 | + def _get_text_payload(self, text: str) -> TextContent: |
262 | 303 | """Returns the standard payload for text content.""" |
263 | 304 | return {"type": "text", "text": text} |
264 | 305 |
|
265 | | - def _get_image_payload(self, mime_type: str, encoded_image: str) -> dict: |
| 306 | + def _get_image_payload(self, mime_type: str, encoded_image: str) -> ImageUrlContent: |
266 | 307 | """Returns the standard payload for image content.""" |
267 | 308 | # Ensure mime_type is safe and starts with "image/" |
268 | 309 | if not mime_type or not mime_type.lower().startswith("image/"): |
|
0 commit comments