
Commit 38d09dc

Fix: Limit number of retries for parse failures (#1569)
When parsing of an LLM response fails, the invalid output is sent back to the LLM to be fixed. This PR threads the number of remaining retries through that call, preventing unbounded recursion. The old `max_retries` did not prevent it, because `generate()` and `parse_output_string()` are co-recursive via the call to `generate()` here: https://github.com/explodinggradients/ragas/blob/ade46fb7c0b5dffb76ef26d876ff021ded9dfa96/src/ragas/prompt/pydantic_prompt.py#L406. The result was a prompt that kept growing across recursive calls (with nested versions becoming increasingly deeply quoted) until it was too big for the LLM to process.

Addresses #1538
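The shape of the fix, as a minimal self-contained sketch (hypothetical names; `fix_output` and `parse` stand in for the real `generate()` and `parse_output_string()`): a retry budget is threaded through every co-recursive call and decremented on each fix attempt, so the chain is bounded regardless of what the LLM returns.

```python
import json

def fix_output(bad: str) -> str:
    # Stand-in for the LLM round trip that asks the model to repair its output.
    return bad  # a real implementation would return a (hopefully) repaired string

def parse(output: str, retries_left: int = 3) -> dict:
    try:
        return json.loads(output)
    except json.JSONDecodeError:
        if retries_left > 0:
            # The shrinking budget is what terminates the co-recursion; the old
            # constant max_retries=3 was reset on every hop and never reached zero.
            return parse(fix_output(output), retries_left - 1)
        raise RuntimeError("output parser failed, retries exhausted")
```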
1 parent 2cf7622 commit 38d09dc

File tree

3 files changed (+38 -13 lines)


src/ragas/exceptions.py

Lines changed: 2 additions & 2 deletions
@@ -26,9 +26,9 @@ class RagasOutputParserException(RagasException):
     Exception raised when the output parser fails to parse the output.
     """
 
-    def __init__(self, num_retries: int):
+    def __init__(self):
         msg = (
-            f"The output parser failed to parse the output after {num_retries} retries."
+            "The output parser failed to parse the output including retries."
         )
         super().__init__(msg)
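Since the constructor no longer takes a retry count, handlers only need the exception type. A hedged sketch of a caller after this change (`run_parser` is a hypothetical stand-in for a call that exhausts its retries):

```python
from ragas.exceptions import RagasOutputParserException

def run_parser():
    # Hypothetical stand-in: pretend the retry budget was exhausted.
    raise RagasOutputParserException()

try:
    result = run_parser()
except RagasOutputParserException as exc:
    print(f"giving up: {exc}")
```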

src/ragas/prompt/pydantic_prompt.py

Lines changed: 13 additions & 11 deletions
@@ -93,6 +93,7 @@ async def generate(
         temperature: t.Optional[float] = None,
         stop: t.Optional[t.List[str]] = None,
         callbacks: t.Optional[Callbacks] = None,
+        retries_left: int = 3,
     ) -> OutputModel:
         """
         Generate a single output using the provided language model and input data.
@@ -111,6 +112,8 @@ async def generate(
             A list of stop sequences to end generation.
         callbacks : Callbacks, optional
             Callback functions to be called during the generation process.
+        retries_left : int, optional
+            Number of retry attempts for an invalid LLM response
 
         Returns
         -------
@@ -131,6 +134,7 @@ async def generate(
             temperature=temperature,
             stop=stop,
             callbacks=callbacks,
+            retries_left=retries_left,
         )
         return output_single[0]
 
@@ -142,6 +146,7 @@ async def generate_multiple(
         temperature: t.Optional[float] = None,
         stop: t.Optional[t.List[str]] = None,
         callbacks: t.Optional[Callbacks] = None,
+        retries_left: int = 3,
     ) -> t.List[OutputModel]:
         """
         Generate multiple outputs using the provided language model and input data.
@@ -160,6 +165,8 @@ async def generate_multiple(
             A list of stop sequences to end generation.
         callbacks : Callbacks, optional
             Callback functions to be called during the generation process.
+        retries_left : int, optional
+            Number of retry attempts for an invalid LLM response
 
         Returns
         -------
@@ -198,7 +205,7 @@ async def generate_multiple(
                 prompt_value=prompt_value,
                 llm=llm,
                 callbacks=prompt_cb,
-                max_retries=3,
+                retries_left=retries_left,
             )
             processed_output = self.process_output(answer, data)  # type: ignore
             output_models.append(processed_output)
@@ -390,14 +397,14 @@ async def parse_output_string(
         prompt_value: PromptValue,
         llm: BaseRagasLLM,
         callbacks: Callbacks,
-        max_retries: int = 1,
+        retries_left: int = 1,
     ):
         callbacks = callbacks or []
         try:
             jsonstr = extract_json(output_string)
             result = super().parse(jsonstr)
         except OutputParserException:
-            if max_retries != 0:
+            if retries_left != 0:
                 retry_rm, retry_cb = new_group(
                     name="fix_output_format",
                     inputs={"output_string": output_string},
@@ -410,17 +417,12 @@ async def parse_output_string(
                         prompt_value=prompt_value.to_string(),
                     ),
                     callbacks=retry_cb,
+                    retries_left=retries_left - 1,
                 )
                 retry_rm.on_chain_end({"fixed_output_string": fixed_output_string})
-                return await self.parse_output_string(
-                    output_string=fixed_output_string.text,
-                    prompt_value=prompt_value,
-                    llm=llm,
-                    max_retries=max_retries - 1,
-                    callbacks=callbacks,
-                )
+                result = fixed_output_string
             else:
-                raise RagasOutputParserException(num_retries=max_retries)
+                raise RagasOutputParserException()
         return result
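With the parameter threaded through, callers can size the budget per call instead of relying on the old constant. A small usage sketch against the signatures in this diff (`prompt`, `llm`, and `data` are assumed to exist; `retries_left` defaults to 3):

```python
from ragas.exceptions import RagasOutputParserException

async def generate_with_budget(prompt, llm, data):
    try:
        # Allow up to five fix-up round trips before giving up.
        return await prompt.generate(llm=llm, data=data, retries_left=5)
    except RagasOutputParserException:
        return None  # budget exhausted; the output never parsed
```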

tests/unit/test_prompt.py

Lines changed: 23 additions & 0 deletions
@@ -3,6 +3,7 @@
 import pytest
 from langchain_core.outputs import Generation, LLMResult
 from langchain_core.prompt_values import StringPromptValue
+from pydantic import BaseModel
 
 from ragas.llms.base import BaseRagasLLM
 from ragas.prompt import StringIO, StringPrompt
@@ -203,3 +204,25 @@ def test_prompt_class_attributes():
     p.examples = []
     assert p.instruction != p_another_instance.instruction
     assert p.examples != p_another_instance.examples
+
+
+@pytest.mark.asyncio
+async def test_prompt_parse_retry():
+    from ragas.prompt import PydanticPrompt, StringIO
+    from ragas.exceptions import RagasOutputParserException
+
+    class OutputModel(BaseModel):
+        example: str
+
+    class Prompt(PydanticPrompt[StringIO, OutputModel]):
+        instruction = ""
+        input_model = StringIO
+        output_model = OutputModel
+
+    echo_llm = EchoLLM(run_config=RunConfig())
+    prompt = Prompt()
+    with pytest.raises(RagasOutputParserException):
+        await prompt.generate(
+            data=StringIO(text="this prompt will be echoed back as invalid JSON"),
+            llm=echo_llm,
+        )
