Skip to content

Commit 8deca92

Browse files
authored
Fix: Resolve Agentic failure loop with improved json parsing (#1541)
Fixes #1538 - Add final instruction for better JSON schema adherence and reduce "filler" text - Add `extract_json` function to identify JSON by pairs of `[]` or `{}` - Add unit tests for `extract_json` **Notes:** In my testing, these changes vastly reduce the number of times the 'repair' agent is triggered. However, this does not fix _all_ parse errors: 1. In some cases, the output is valid JSON, but the keys are invalid. 2. In some cases, especially with more complex Pydantic models (e.g., `Themes`, which is composed of `List[Theme]`), the Pydantic JSON schema includes `$defs`. As a result, some LLMs will return `$defs` or include `"$schema": "https://json-schema.org/draft/..."`; either of these additional keys will cause the Pydantic parser to fail.
1 parent 849f3e3 commit 8deca92

File tree

3 files changed

+135
-3
lines changed

3 files changed

+135
-3
lines changed

src/ragas/prompt/pydantic_prompt.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ragas.exceptions import RagasOutputParserException
1717

1818
from .base import BasePrompt, StringIO, _check_if_language_is_supported
19-
from .utils import get_all_strings, update_strings
19+
from .utils import extract_json, get_all_strings, update_strings
2020

2121
if t.TYPE_CHECKING:
2222
from langchain_core.callbacks import Callbacks
@@ -82,6 +82,7 @@ def to_string(self, data: t.Optional[InputModel] = None) -> str:
8282
if data is not None
8383
else "input: (None)\n"
8484
)
85+
+ "Respond only with a valid JSON object that complies with the specified schema.\n"
8586
+ "output: "
8687
)
8788

@@ -393,7 +394,8 @@ async def parse_output_string(
393394
):
394395
callbacks = callbacks or []
395396
try:
396-
result = super().parse(output_string)
397+
jsonstr = extract_json(output_string)
398+
result = super().parse(jsonstr)
397399
except OutputParserException:
398400
if max_retries != 0:
399401
retry_rm, retry_cb = new_group(

src/ragas/prompt/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,43 @@ def replace_string(s: str) -> str:
6464
return {k: update_strings(v, old_strings, new_strings) for k, v in obj.items()}
6565

6666
return copy.deepcopy(obj)
67+
68+
69+
def extract_json(text: str) -> str:
    """Extract the first JSON object or array embedded in a text blob.

    The outermost structure is located by matching its opening ``[`` or ``{``
    with the corresponding closing delimiter. Delimiters that appear inside
    double-quoted JSON string literals (including after backslash escapes)
    are ignored, so a value such as ``"a } b"`` does not end the structure
    prematurely.

    Warning: This will identify the first json structure!

    Args:
        text: Arbitrary text, e.g. an LLM response, that may contain JSON
            surrounded by filler prose or a markdown code fence.

    Returns:
        The substring spanning the first balanced ``{...}`` or ``[...]``.
        If no opening delimiter exists, or the structure never closes, the
        text is returned unextracted (starting at the markdown json fence
        when one is present).
    """
    # check for markdown indicator; if present, start there
    md_json_idx = text.find("```json")
    if md_json_idx != -1:
        text = text[md_json_idx:]

    # search for json delimiter pairs
    indices = [idx for idx in (text.find("["), text.find("{")) if idx != -1]

    # If no delimiter found, return the text as-is
    if not indices:
        return text
    start_idx = min(indices)

    # Identify the exterior delimiters defining JSON
    open_char = text[start_idx]
    close_char = "]" if open_char == "[" else "}"

    depth = 0  # nesting depth of the exterior delimiter type
    in_string = False  # currently inside a double-quoted string literal
    escaped = False  # previous character was a backslash inside a string
    for i, char in enumerate(text[start_idx:], start=start_idx):
        if in_string:
            # Inside a string only an unescaped quote is significant;
            # brackets and braces are plain characters here.
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
            continue

        if char == '"':
            in_string = True
        elif char == open_char:
            depth += 1
        elif char == close_char:
            depth -= 1
            # When depth returns to zero, we've found a complete structure
            if depth == 0:
                return text[start_idx : i + 1]

    return text  # In case of unbalanced JSON, return the text unextracted

tests/unit/prompt/test_prompt_utils.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from pydantic import BaseModel
55

6-
from ragas.prompt.utils import get_all_strings, update_strings
6+
from ragas.prompt.utils import extract_json, get_all_strings, update_strings
77

88

99
class Category(BaseModel):
@@ -122,3 +122,93 @@ def test_update_strings(obj, old_strings, new_strings):
122122

123123
assert get_all_strings(updated_obj) == new_strings
124124
assert get_all_strings(obj) == old_strings
125+
126+
127+
class TestExtractJson:
    """Unit tests for ``extract_json``.

    The class-level fixtures combine filler prose with JSON payloads so the
    parametrized test can check that only the first JSON structure is
    returned, regardless of surrounding text.
    """

    # Filler text of the kind an LLM may place around its JSON answer.
    prefix = "Here's the generated abstract conceptual question in the requested JSON format: "
    suffix = "Would you like me to explain in more detail?"
    # JSON payloads; named so they do not shadow the ``object`` builtin.
    json_object = """{"key": "value"}"""
    json_array = """[1, 2, 3]"""
    nested = """{"outer": {"inner": [1, 2, 3]}}"""

    # (input text, expected extracted JSON) pairs.
    test_cases = [
        (json_object, json_object),
        (json_array, json_array),
        (nested, nested),
        (prefix + json_object, json_object),
        (json_object + suffix, json_object),
        (prefix + json_object + suffix, json_object),
        (prefix + json_array, json_array),
        (json_array + suffix, json_array),
        (prefix + json_array + suffix, json_array),
        (prefix + nested, nested),
        (nested + suffix, nested),
        (prefix + nested + suffix, nested),
        (json_object + json_array + nested, json_object),
        (nested + json_object + json_array, nested),
    ]

    @pytest.mark.parametrize("text, expected", test_cases)
    def test_extract_json(self, text, expected):
        """The first JSON structure is returned without surrounding text."""
        assert extract_json(text) == expected

    def test_extract_empty_array(self):
        """An empty array embedded in prose is extracted."""
        text = "Here is an empty array: [] and some text."
        expected = "[]"
        assert extract_json(text) == expected

    def test_extract_empty_object(self):
        """An empty object embedded in prose is extracted."""
        text = "Here is an empty object: {} and more text."
        expected = "{}"
        assert extract_json(text) == expected

    def test_extract_incomplete_json(self):
        """Unbalanced JSON is returned unchanged rather than truncated."""
        text = 'Not complete: {"key": "value", "array": [1, 2, 3'
        expected = 'Not complete: {"key": "value", "array": [1, 2, 3'
        assert extract_json(text) == expected

    def test_markdown_json(self):
        """A markdown json fence takes priority over earlier braces in code."""
        text = """
```python
import json

def modify_query(input_data):
query = input_data["query"]
style = input_data["style"]
length = input_data["length"]

if style == "Poor grammar":
# Poor grammar modifications (simplified for brevity)
query = query.replace("How", "how")
query = query.replace("do", "does")
query = query.replace("terms of", "in terms of")
query = query.replace("and", "")

if length == "long":
# Long text modifications (simplified for brevity)
query += "?"

return {
"text": query
}

input_data = {
"query": "How can the provided commands be used to manage and troubleshoot namespaces in a Kubernetes environment?",
"style": "Poor grammar",
"length": "long"
}

output = modify_query(input_data)
print(json.dumps(output, indent=4))
```

Output:
```json
{"text": "how does the provided commands be used to manage and troubleshoot namespaces in a Kubernetes environment?"}
```
This Python function `modify_query` takes an input dictionary with query, style, and length as keys. It applies modifications based on the specified style (Poor grammar) and length (long). The modified query is then returned as a JSON object.

Note: This implementation is simplified for brevity and may not cover all possible edge cases or nuances of natural language processing.
"""
        expected = """{"text": "how does the provided commands be used to manage and troubleshoot namespaces in a Kubernetes environment?"}"""
        assert extract_json(text) == expected

0 commit comments

Comments
 (0)