Commit c1662a7

feat: throw error when max_token limit is reached (#1549)
1 parent 8deca92 commit c1662a7

File tree: 3 files changed, +86 −32 lines

src/ragas/callbacks.py

Lines changed: 5 additions & 3 deletions
@@ -151,14 +151,16 @@ def parse_run_traces(
         metric_traces = MetricTrace()
         for metric_uuid in row_trace.children:
             metric_trace = traces[metric_uuid]
-            metric_traces.scores[metric_trace.name] = metric_trace.outputs["output"]
+            metric_traces.scores[metric_trace.name] = metric_trace.outputs.get(
+                "output", {}
+            )
             # get all the prompt IO from the metric trace
             prompt_traces = {}
             for i, prompt_uuid in enumerate(metric_trace.children):
                 prompt_trace = traces[prompt_uuid]
                 prompt_traces[f"{i}_{prompt_trace.name}"] = {
-                    "input": prompt_trace.inputs["data"],
-                    "output": prompt_trace.outputs["output"],
+                    "input": prompt_trace.inputs.get("data", {}),
+                    "output": prompt_trace.outputs.get("output", {}),
                 }
             metric_traces[f"{metric_trace.name}"] = prompt_traces
         parased_traces.append(metric_traces)
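The switch from direct indexing to dict.get with a default means a trace whose outputs were never recorded (for example, because generation was cut off) no longer raises a KeyError while parsing. A minimal illustrative sketch, not part of the commit:

# Illustrative sketch: why .get with a default is safer when a trace
# may be missing its "output" key.
outputs = {}  # a trace that produced no output

# outputs["output"] would raise KeyError here
value = outputs.get("output", {})  # falls back to an empty dict instead
print(value)  # -> {}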

src/ragas/exceptions.py

Lines changed: 10 additions & 0 deletions
@@ -31,3 +31,13 @@ def __init__(self, num_retries: int):
             f"The output parser failed to parse the output after {num_retries} retries."
         )
         super().__init__(msg)
+
+
+class LLMDidNotFinishException(RagasException):
+    """
+    Exception raised when the LLM did not finish.
+    """
+
+    def __init__(self):
+        msg = "The LLM generation was not completed. Please try increasing max_tokens and try again."
+        super().__init__(msg)
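With the new exception, callers can tell a truncated generation apart from other failures. A hypothetical usage sketch; `dataset` and `metrics` here are assumptions for illustration, not part of this diff:

from ragas import evaluate
from ragas.exceptions import LLMDidNotFinishException

try:
    # dataset and metrics are assumed to be defined elsewhere
    results = evaluate(dataset, metrics=metrics)
except LLMDidNotFinishException:
    # The LLM stopped before completing its output, typically because it
    # hit the max_tokens ceiling; raise max_tokens on the model and retry.
    raise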

src/ragas/llms/base.py

Lines changed: 71 additions & 29 deletions
@@ -1,25 +1,25 @@
 from __future__ import annotations

-import asyncio
 import logging
 import typing as t
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from functools import partial

 from langchain_community.chat_models.vertexai import ChatVertexAI
 from langchain_community.llms import VertexAI
 from langchain_core.language_models import BaseLanguageModel
-from langchain_core.outputs import Generation, LLMResult
+from langchain_core.outputs import ChatGeneration, Generation, LLMResult
 from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain_openai.llms import AzureOpenAI, OpenAI
 from langchain_openai.llms.base import BaseOpenAI

+from ragas.exceptions import LLMDidNotFinishException
 from ragas.integrations.helicone import helicone_config
-from ragas.run_config import RunConfig, add_async_retry, add_retry
+from ragas.run_config import RunConfig, add_async_retry

 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
+    from langchain_core.messages import BaseMessage
     from langchain_core.prompt_values import PromptValue
     from llama_index.core.base.llms.base import BaseLLM

@@ -55,6 +55,12 @@ def get_temperature(self, n: int) -> float:
         """Return the temperature to use for completion based on n."""
         return 0.3 if n > 1 else 1e-8

+    def is_finished(self, response: LLMResult) -> bool:
+        logger.warning(
+            f"is_finished not implemented for {self.__class__.__name__}. Will default to True."
+        )
+        return True
+
     @abstractmethod
     def generate_text(
         self,

@@ -82,36 +88,27 @@ async def generate(
         temperature: t.Optional[float] = None,
         stop: t.Optional[t.List[str]] = None,
         callbacks: Callbacks = None,
-        is_async: bool = True,
     ) -> LLMResult:
         """Generate text using the given event loop."""

         if temperature is None:
             temperature = self.get_temperature(n)

-        if is_async:
-            agenerate_text_with_retry = add_async_retry(
-                self.agenerate_text, self.run_config
-            )
-            return await agenerate_text_with_retry(
-                prompt=prompt,
-                n=n,
-                temperature=temperature,
-                stop=stop,
-                callbacks=callbacks,
-            )
-        else:
-            loop = asyncio.get_event_loop()
-            generate_text_with_retry = add_retry(self.generate_text, self.run_config)
-            generate_text = partial(
-                generate_text_with_retry,
-                prompt=prompt,
-                n=n,
-                temperature=temperature,
-                stop=stop,
-                callbacks=callbacks,
-            )
-            return await loop.run_in_executor(None, generate_text)
+        agenerate_text_with_retry = add_async_retry(
+            self.agenerate_text, self.run_config
+        )
+        result = await agenerate_text_with_retry(
+            prompt=prompt,
+            n=n,
+            temperature=temperature,
+            stop=stop,
+            callbacks=callbacks,
+        )
+
+        # check there are no max_token issues
+        if not self.is_finished(result):
+            raise LLMDidNotFinishException()
+        return result


 class LangchainLLMWrapper(BaseRagasLLM):

@@ -123,12 +120,57 @@ class LangchainLLMWrapper(BaseRagasLLM):
     """

     def __init__(
-        self, langchain_llm: BaseLanguageModel, run_config: t.Optional[RunConfig] = None
+        self,
+        langchain_llm: BaseLanguageModel,
+        run_config: t.Optional[RunConfig] = None,
+        is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
     ):
         self.langchain_llm = langchain_llm
         if run_config is None:
             run_config = RunConfig()
         self.set_run_config(run_config)
+        self.is_finished_parser = is_finished_parser
+
+    def is_finished(self, response: LLMResult) -> bool:
+        """
+        Parse the response to check if the LLM finished by checking the finish_reason
+        or stop_reason.
+        """
+        if self.is_finished_parser is not None:
+            return self.is_finished_parser(response)
+        # if no parser is provided default to our own
+
+        is_finished_list = []
+        for g in response.flatten():
+            resp = g.generations[0][0]
+            if resp.generation_info is not None:
+                # generation_info is provided - so we parse that
+
+                # OpenAI uses "stop" to indicate that the generation is finished
+                # and is stored in 'finish_reason' key in generation_info
+                if resp.generation_info.get("finish_reason") is not None:
+                    is_finished_list.append(
+                        resp.generation_info.get("finish_reason") == "stop"
+                    )
+                # provide more conditions here
+                # https://github.com/explodinggradients/ragas/issues/1548
+
+            # if generation_info is empty, we parse the response_metadata
+            # this is less reliable
+            elif t.cast(ChatGeneration, resp).message is not None:
+                resp_message: BaseMessage = t.cast(ChatGeneration, resp).message
+                if resp_message.response_metadata.get("finish_reason") is not None:
+                    is_finished_list.append(
+                        resp_message.response_metadata.get("finish_reason") == "stop"
+                    )
+                elif resp_message.response_metadata.get("stop_reason") is not None:
+                    is_finished_list.append(
+                        resp_message.response_metadata.get("stop_reason") == "end_turn"
+                    )
+            # default to True
+            else:
+                is_finished_list.append(True)
+        return all(is_finished_list)

     def generate_text(
         self,