Skip to content

Commit b455475

Browse files
Retry for failed json loading (#364)
Co-authored-by: Jithin James <[email protected]>
1 parent 303bbca commit b455475

File tree

6 files changed

+142
-11
lines changed

6 files changed

+142
-11
lines changed

src/ragas/metrics/_answer_correctness.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from ragas.metrics._answer_similarity import AnswerSimilarity
1212
from ragas.metrics.base import EvaluationMode, MetricWithLLM
13-
from ragas.utils import load_as_json
13+
from ragas.utils import json_loader
1414

1515
if t.TYPE_CHECKING:
1616
from langchain.callbacks.base import Callbacks
@@ -118,7 +118,7 @@ def _score_batch(
118118

119119
f1_score = []
120120
for prediction in outputs:
121-
prediction = load_as_json(prediction[0].text)
121+
prediction = json_loader.safe_load(prediction[0].text, self.llm)
122122
prediction = [
123123
item.get(key_map[k], np.nan)
124124
for item in prediction

src/ragas/metrics/_answer_relevance.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ragas.embeddings.base import embedding_factory
1313
from ragas.exceptions import OpenAIKeyNotFound
1414
from ragas.metrics.base import EvaluationMode, MetricWithLLM
15-
from ragas.utils import load_as_json
15+
from ragas.utils import json_loader
1616

1717
if t.TYPE_CHECKING:
1818
from langchain.callbacks.base import Callbacks
@@ -125,7 +125,10 @@ def _score_batch(
125125
n=self.strictness,
126126
callbacks=batch_group,
127127
)
128-
results = [[load_as_json(i.text) for i in r] for r in results.generations]
128+
results = [
129+
[json_loader.safe_load(i.text, self.llm) for i in r]
130+
for r in results.generations
131+
]
129132
scores = []
130133
for question, result in zip(questions, results):
131134
gen_questions = [item.get("question", "") for item in result]

src/ragas/metrics/_context_precision.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
1010

1111
from ragas.metrics.base import EvaluationMode, MetricWithLLM
12-
from ragas.utils import load_as_json
12+
from ragas.utils import json_loader
1313

1414
if t.TYPE_CHECKING:
1515
from langchain.callbacks.base import Callbacks
@@ -94,7 +94,9 @@ def _score_batch(
9494
scores = []
9595

9696
for response in grouped_responses:
97-
response = [load_as_json(item) for item in sum(response, [])]
97+
response = [
98+
json_loader.safe_load(item, self.llm) for item in sum(response, [])
99+
]
98100
response = [
99101
int("yes" in resp.get("verdict", " ").lower())
100102
if resp.get("verdict")

src/ragas/metrics/_context_recall.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
1010

1111
from ragas.metrics.base import EvaluationMode, MetricWithLLM
12-
from ragas.utils import load_as_json
12+
from ragas.utils import json_loader
1313

1414
if t.TYPE_CHECKING:
1515
from langchain.callbacks.base import Callbacks
@@ -118,7 +118,7 @@ def _score_batch(
118118
responses = [[i.text for i in r] for r in results.generations]
119119
scores = []
120120
for response in responses:
121-
response = load_as_json(response[0])
121+
response = json_loader.safe_load(response[0], self.llm)
122122
if response:
123123
response = [
124124
int(item.get("Attributed", "").lower() == "yes")

src/ragas/metrics/_faithfulness.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
99

1010
from ragas.metrics.base import EvaluationMode, MetricWithLLM
11-
from ragas.utils import load_as_json
11+
from ragas.utils import json_loader
1212

1313
if t.TYPE_CHECKING:
1414
from datasets import Dataset
@@ -154,7 +154,9 @@ def _score_batch(
154154

155155
prompts = []
156156
for context, output in zip(contexts, result.generations):
157-
statements = load_as_json(output[0].text).get("statements", [])
157+
statements = json_loader.safe_load(output[0].text, self.llm).get(
158+
"statements", []
159+
)
158160
statements = statements if statements != [] else ["Nil"]
159161
statements_str: str = "\n".join(
160162
[f"statement_{i+1}: {st}" for i, st in enumerate(statements)]
@@ -170,7 +172,7 @@ def _score_batch(
170172
verdict_score_map = {"yes": 1, "no": 0, "null": np.nan}
171173
scores = []
172174
for output in outputs:
173-
output = load_as_json(output[0].text)
175+
output = json_loader.safe_load(output[0].text, self.llm)
174176
output = output if output else []
175177
faithful_statements = sum(
176178
verdict_score_map.get(dict.get("verdict", "").lower(), np.nan)

src/ragas/utils.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,17 @@
22

33
import json
44
import os
5+
import typing as t
56
import warnings
7+
from dataclasses import dataclass
68
from functools import lru_cache
79

10+
from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
11+
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
12+
13+
if t.TYPE_CHECKING:
14+
from ragas.llms import RagasLLM
15+
816
DEBUG_ENV_VAR = "RAGAS_DEBUG"
917
# constant to tell us that there is no key passed to the llm/embeddings
1018
NO_KEY = "no-key"
@@ -29,3 +37,119 @@ def load_as_json(text):
2937
warnings.warn(f"Invalid json: {e}")
3038

3139
return {}
40+
41+
42+
# Few-shot repair prompt used by JsonLoader._fix_to_json: asks the LLM to
# rewrite malformed JSON (missing/trailing commas, unescaped inner quotes)
# into valid JSON.  Doubled braces ({{ }}) are literal braces escaped for the
# template engine; {input} is the only template variable.
JSON_PROMPT = HumanMessagePromptTemplate.from_template(
    """

Rewrite the input into valid json


Input:
{{
    "name": "John Doe",
    "age": 30,
    "isStudent": false
    "address": {{
        "street": "123 Main St",
        "city": "Anytown",
        "state": "CA",
    }}
    "hobbies": ["reading", "swimming", "cycling"]
}}
Output:
{{
    "name": "John Doe",
    "age": 30,
    "isStudent": false,
    "address": {{
        "street": "123 Main St",
        "city": "Anytown",
        "state": "CA"
    }},
    "hobbies": ["reading", "swimming", "cycling"]
}}


Input:
{{
    "statement": "The Earth is also known as "Terra" "
}}
Output:
{{
    "statement": "The Earth is also known as 'Terra'"
}}

Input:
{input}

Output:
"""
)
89+
90+
91+
@dataclass
class JsonLoader:
    """Parse JSON out of raw LLM output, asking the LLM to repair it on failure.

    ``safe_load`` extracts the outermost JSON object/array from ``text`` and
    parses it.  When parsing fails, the text is sent back to the LLM with
    ``JSON_PROMPT`` to be rewritten into valid JSON, up to ``max_retries``
    times.  Returns ``{}`` when every attempt fails.
    """

    # Maximum number of LLM-assisted repair round-trips.
    max_retries: int = 2

    def safe_load(self, text: str, llm: RagasLLM):
        """Return the parsed JSON payload of ``text``, or ``{}`` if unrecoverable.

        Makes at most ``max_retries + 1`` parse attempts and at most
        ``max_retries`` LLM repair calls.
        """
        for attempt in range(self.max_retries + 1):
            try:
                start, end = self._find_outermost_json(text)
                # json.JSONDecodeError subclasses ValueError; a "not found"
                # result of (-1, -1) yields the empty slice text[-1:-1] == ""
                # which also raises ValueError here.
                return json.loads(text[start:end])
            except ValueError as e:
                if attempt == self.max_retries:
                    # Out of retries: stop here instead of issuing one more
                    # LLM repair call whose result would never be re-parsed.
                    warnings.warn(f"Invalid json after {attempt} retries: {e}")
                    break
                text = self._fix_to_json(text, llm)

        return {}

    def _fix_to_json(
        self,
        text,
        llm,
        callbacks: t.Optional[CallbackManager] = None,
        callback_group_name: str = "batch",
    ):
        """Ask ``llm`` (via JSON_PROMPT) to rewrite ``text`` into valid JSON.

        Returns the raw text of the first generation; no parsing is done here.
        """
        # TODO (executor)
        with trace_as_chain_group(
            callback_group_name, callback_manager=callbacks
        ) as batch_group:
            human_prompt = ChatPromptTemplate.from_messages(
                [JSON_PROMPT.format(input=text)]
            )
            results = llm.generate(
                [human_prompt],
                n=1,
                callbacks=batch_group,
            )
            return results.generations[0][0].text

    def _find_outermost_json(self, text):
        """Return ``(start, end)`` slice bounds of the first balanced outermost
        JSON object/array in ``text``, or ``(-1, -1)`` if none is found.

        Uses a bracket stack: scanning stops at the first mismatched closing
        brace/bracket (invalid JSON) and succeeds when the stack empties.
        """
        stack = []
        start_index = -1

        for i, char in enumerate(text):
            if char in "{[":
                if len(stack) == 0:
                    start_index = i
                stack.append(char)

            elif char in "}]":
                if len(stack) > 0:
                    last = stack.pop()
                    if (char == "}" and last != "{") or (char == "]" and last != "["):
                        # Mismatched closing brace/bracket, invalid JSON
                        break

                    if len(stack) == 0 and start_index != -1:
                        # Found a valid outermost JSON
                        return (
                            start_index,
                            i + 1,
                        )  # Add 1 to include the closing brace/bracket in the range

        return -1, -1  # No valid JSON found
153+
154+
155+
# Module-level singleton shared by the metrics modules to parse (and, on
# failure, LLM-repair) JSON emitted by model completions.
json_loader = JsonLoader()

0 commit comments

Comments
 (0)