Skip to content

Commit 88ef2e0

Browse files
authored
fix: add support for both langchain llm and ragas llm (#2229)
## Issue Link / Problem Description

Fixes #2085.

## Changes Made

- Handle both LangChain LLMs and Ragas LLMs.
1 parent e7884f8 commit 88ef2e0

File tree

3 files changed

+115
-37
lines changed

3 files changed

+115
-37
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ _experiments/
169169
**/fil-result/
170170
src/ragas/_version.py
171171
experimental/ragas_experimental/_version.py
172+
examples/ragas_examples/_version.py
172173
.vscode
173174
.envrc
174175
uv.lock

src/ragas/prompt/multi_modal_prompt.py

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,44 @@
1212
from urllib.parse import urlparse
1313

1414
import requests
15+
from langchain_core.language_models import BaseLanguageModel
1516
from langchain_core.messages import BaseMessage, HumanMessage
1617
from langchain_core.prompt_values import PromptValue
1718
from PIL import Image
1819
from pydantic import BaseModel
20+
from typing_extensions import TypedDict
1921

2022
from ragas.callbacks import ChainType, new_group
2123
from ragas.exceptions import RagasOutputParserException
22-
from ragas.prompt.pydantic_prompt import PydanticPrompt, RagasOutputParser
24+
from ragas.prompt.pydantic_prompt import (
25+
PydanticPrompt,
26+
RagasOutputParser,
27+
is_langchain_llm,
28+
)
2329

2430
if t.TYPE_CHECKING:
2531
from langchain_core.callbacks import Callbacks
2632

27-
from ragas.llms.base import BaseRagasLLM
33+
from ragas.llms.base import BaseRagasLLM
2834

2935
# type variables for input and output models
3036
InputModel = t.TypeVar("InputModel", bound=BaseModel)
3137
OutputModel = t.TypeVar("OutputModel", bound=BaseModel)
3238

39+
40+
# Specific typed dictionaries for message content
41+
class TextContent(TypedDict):
42+
type: t.Literal["text"]
43+
text: str
44+
45+
46+
class ImageUrlContent(TypedDict):
47+
type: t.Literal["image_url"]
48+
image_url: dict[str, str]
49+
50+
51+
MessageContent = t.Union[TextContent, ImageUrlContent]
52+
3353
logger = logging.getLogger(__name__)
3454

3555
# --- Constants for Security Policy ---
@@ -101,7 +121,7 @@ def to_prompt_value(self, data: t.Optional[InputModel] = None):
101121

102122
async def generate_multiple(
103123
self,
104-
llm: BaseRagasLLM,
124+
llm: t.Union[BaseRagasLLM, BaseLanguageModel],
105125
data: InputModel,
106126
n: int = 1,
107127
temperature: t.Optional[float] = None,
@@ -146,26 +166,47 @@ async def generate_multiple(
146166
metadata={"type": ChainType.RAGAS_PROMPT},
147167
)
148168
prompt_value = self.to_prompt_value(processed_data)
149-
resp = await llm.generate(
150-
prompt_value,
151-
n=n,
152-
temperature=temperature,
153-
stop=stop,
154-
callbacks=prompt_cb,
155-
)
169+
170+
# Handle both LangChain LLMs and Ragas LLMs
171+
# LangChain LLMs have agenerate() for async, generate() for sync
172+
# Ragas LLMs have generate() as async method
173+
if is_langchain_llm(llm):
174+
# This is a LangChain LLM - use agenerate_prompt()
175+
langchain_llm = t.cast(BaseLanguageModel, llm)
176+
resp = await langchain_llm.agenerate_prompt(
177+
[prompt_value],
178+
stop=stop,
179+
callbacks=prompt_cb,
180+
)
181+
else:
182+
# This is a Ragas LLM - use generate()
183+
ragas_llm = t.cast(BaseRagasLLM, llm)
184+
resp = await ragas_llm.generate(
185+
prompt_value,
186+
n=n,
187+
temperature=temperature,
188+
stop=stop,
189+
callbacks=prompt_cb,
190+
)
156191

157192
output_models = []
158193
parser = RagasOutputParser(pydantic_object=self.output_model) # type: ignore
159194
for i in range(n):
160195
output_string = resp.generations[0][i].text
161196
try:
162-
answer = await parser.parse_output_string(
163-
output_string=output_string,
164-
prompt_value=prompt_value, # type: ignore
165-
llm=llm,
166-
callbacks=prompt_cb,
167-
retries_left=retries_left,
168-
)
197+
# For the parser, we need a BaseRagasLLM, so if it's a LangChain LLM, we need to handle this
198+
if is_langchain_llm(llm):
199+
# Skip parsing retry for LangChain LLMs since parser expects BaseRagasLLM
200+
answer = self.output_model.model_validate_json(output_string)
201+
else:
202+
ragas_llm = t.cast(BaseRagasLLM, llm)
203+
answer = await parser.parse_output_string(
204+
output_string=output_string,
205+
prompt_value=prompt_value, # type: ignore
206+
llm=ragas_llm,
207+
callbacks=prompt_cb,
208+
retries_left=retries_left,
209+
)
169210
processed_output = self.process_output(answer, data) # type: ignore
170211
output_models.append(processed_output)
171212
except RagasOutputParserException as e:
@@ -204,7 +245,7 @@ def to_messages(self) -> t.List[BaseMessage]:
204245
# Return empty list or handle as appropriate if all items failed processing
205246
return []
206247

207-
def _securely_process_item(self, item: str) -> t.Optional[t.Dict[str, t.Any]]:
248+
def _securely_process_item(self, item: str) -> t.Optional[MessageContent]:
208249
"""
209250
Securely determines if an item is text, a valid image data URI,
210251
or a fetchable image URL according to policy. Returns the appropriate
@@ -258,11 +299,11 @@ def _looks_like_image_path(self, item: str) -> bool:
258299
_, ext = os.path.splitext(path_part)
259300
return ext.lower() in COMMON_IMAGE_EXTENSIONS
260301

261-
def _get_text_payload(self, text: str) -> dict:
302+
def _get_text_payload(self, text: str) -> TextContent:
262303
"""Returns the standard payload for text content."""
263304
return {"type": "text", "text": text}
264305

265-
def _get_image_payload(self, mime_type: str, encoded_image: str) -> dict:
306+
def _get_image_payload(self, mime_type: str, encoded_image: str) -> ImageUrlContent:
266307
"""Returns the standard payload for image content."""
267308
# Ensure mime_type is safe and starts with "image/"
268309
if not mime_type or not mime_type.lower().startswith("image/"):

src/ragas/prompt/pydantic_prompt.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import typing as t
99

1010
from langchain_core.exceptions import OutputParserException
11+
from langchain_core.language_models import BaseLanguageModel
1112
from langchain_core.output_parsers import PydanticOutputParser
1213
from langchain_core.prompt_values import StringPromptValue as PromptValue
1314
from pydantic import BaseModel
@@ -22,7 +23,21 @@
2223
if t.TYPE_CHECKING:
2324
from langchain_core.callbacks import Callbacks
2425

25-
from ragas.llms.base import BaseRagasLLM
26+
from ragas.llms.base import BaseRagasLLM
27+
28+
29+
def is_langchain_llm(llm: t.Union[BaseRagasLLM, BaseLanguageModel]) -> bool:
30+
"""
31+
Detect if an LLM is a LangChain LLM or a Ragas LLM.
32+
33+
Args:
34+
llm: The LLM instance to check
35+
36+
Returns:
37+
True if it's a LangChain LLM, False if it's a Ragas LLM
38+
"""
39+
return hasattr(llm, "agenerate") and not hasattr(llm, "run_config")
40+
2641

2742
logger = logging.getLogger(__name__)
2843

@@ -87,7 +102,7 @@ def to_string(self, data: t.Optional[InputModel] = None) -> str:
87102

88103
async def generate(
89104
self,
90-
llm: BaseRagasLLM,
105+
llm: t.Union[BaseRagasLLM, BaseLanguageModel],
91106
data: InputModel,
92107
temperature: t.Optional[float] = None,
93108
stop: t.Optional[t.List[str]] = None,
@@ -139,7 +154,7 @@ async def generate(
139154

140155
async def generate_multiple(
141156
self,
142-
llm: BaseRagasLLM,
157+
llm: t.Union[BaseRagasLLM, BaseLanguageModel],
143158
data: InputModel,
144159
n: int = 1,
145160
temperature: t.Optional[float] = None,
@@ -187,26 +202,47 @@ async def generate_multiple(
187202
metadata={"type": ChainType.RAGAS_PROMPT},
188203
)
189204
prompt_value = PromptValue(text=self.to_string(processed_data))
190-
resp = await llm.generate(
191-
prompt_value,
192-
n=n,
193-
temperature=temperature,
194-
stop=stop,
195-
callbacks=prompt_cb,
196-
)
205+
206+
# Handle both LangChain LLMs and Ragas LLMs
207+
# LangChain LLMs have agenerate() for async, generate() for sync
208+
# Ragas LLMs have generate() as async method
209+
if is_langchain_llm(llm):
210+
# This is a LangChain LLM - use agenerate_prompt()
211+
langchain_llm = t.cast(BaseLanguageModel, llm)
212+
resp = await langchain_llm.agenerate_prompt(
213+
[prompt_value],
214+
stop=stop,
215+
callbacks=prompt_cb,
216+
)
217+
else:
218+
# This is a Ragas LLM - use generate()
219+
ragas_llm = t.cast(BaseRagasLLM, llm)
220+
resp = await ragas_llm.generate(
221+
prompt_value,
222+
n=n,
223+
temperature=temperature,
224+
stop=stop,
225+
callbacks=prompt_cb,
226+
)
197227

198228
output_models = []
199229
parser = RagasOutputParser(pydantic_object=self.output_model)
200230
for i in range(n):
201231
output_string = resp.generations[0][i].text
202232
try:
203-
answer = await parser.parse_output_string(
204-
output_string=output_string,
205-
prompt_value=prompt_value,
206-
llm=llm,
207-
callbacks=prompt_cb,
208-
retries_left=retries_left,
209-
)
233+
# For the parser, we need a BaseRagasLLM, so if it's a LangChain LLM, we need to handle this
234+
if is_langchain_llm(llm):
235+
# Skip parsing retry for LangChain LLMs since parser expects BaseRagasLLM
236+
answer = self.output_model.model_validate_json(output_string)
237+
else:
238+
ragas_llm = t.cast(BaseRagasLLM, llm)
239+
answer = await parser.parse_output_string(
240+
output_string=output_string,
241+
prompt_value=prompt_value,
242+
llm=ragas_llm,
243+
callbacks=prompt_cb,
244+
retries_left=retries_left,
245+
)
210246
processed_output = self.process_output(answer, data) # type: ignore
211247
output_models.append(processed_output)
212248
except RagasOutputParserException as e:

0 commit comments

Comments (0)