Skip to content

Commit 27e1c24

Browse files
pberger514 and jjmachan
authored
feat(llms.json_load): Recursively load json lists (#593)
Slightly broken JSON is protected against by the function `ragas.llms.json_load.JsonLoader._find_outermost_json`. However, I've found that for many metrics, GPT-4 can often return slightly broken JSON lists, for which this function returns only the first valid JSON value. Here we wrap `_find_outermost_json` with `_load_all_jsons`, which calls it recursively to load the full JSON list. E.g. the expected output for `'{"1":"2"}, ,, {"3":"4"}]'` is `[{'1': '2'}, {'3': '4'}]` --------- Co-authored-by: jjmachan <[email protected]>
1 parent 3834fe5 commit 27e1c24

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

src/ragas/llms/json_load.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def _safe_load(self, text: str, llm: BaseRagasLLM, callbacks: Callbacks = None):
8383
retry = 0
8484
while retry <= self.max_retries:
8585
try:
86-
start, end = self._find_outermost_json(text)
87-
return json.loads(text[start:end])
86+
_json = self._load_all_jsons(text)
87+
return _json[0] if len(_json) == 1 else _json
8888
except ValueError:
8989
from ragas.llms.prompt import PromptValue
9090

@@ -104,8 +104,8 @@ async def _asafe_load(
104104
retry = 0
105105
while retry <= self.max_retries:
106106
try:
107-
start, end = self._find_outermost_json(text)
108-
return json.loads(text[start:end])
107+
_json = self._load_all_jsons(text)
108+
return _json[0] if len(_json) == 1 else _json
109109
except ValueError:
110110
from ragas.llms.prompt import PromptValue
111111

@@ -126,7 +126,7 @@ async def safe_load(
126126
callbacks: Callbacks = None,
127127
is_async: bool = True,
128128
run_config: RunConfig = RunConfig(),
129-
):
129+
) -> t.Union[t.Dict, t.List]:
130130
if is_async:
131131
_asafe_load_with_retry = add_async_retry(self._asafe_load, run_config)
132132
return await _asafe_load_with_retry(text=text, llm=llm, callbacks=callbacks)
@@ -141,6 +141,16 @@ async def safe_load(
141141
safe_load,
142142
)
143143

144+
def _load_all_jsons(self, text):
    """Extract every JSON value embedded in ``text`` and return them as a list.

    LLM responses often contain several JSON objects separated by junk
    (e.g. ``'{"1":"2"}, ,, {"3":"4"}]'``).  ``_find_outermost_json``
    locates only the first valid JSON span, so we repeatedly consume
    that span and rescan the remainder until none is left.

    Iterative rather than recursive: a response with many fragments
    cannot hit the interpreter's recursion limit.  The consumed span is
    removed by positional slicing instead of ``str.replace`` -- replace
    removes by value (risking an identical substring elsewhere) and
    rescans the whole string each time.

    Raises ``ValueError`` when no JSON is found or a located span fails
    to parse; callers rely on this to trigger their retry logic.
    """
    jsons = []
    while True:
        start, end = self._find_outermost_json(text)
        if (start, end) == (-1, -1):
            break
        # json.loads raises ValueError (JSONDecodeError) on a bad span,
        # propagating to the caller's retry loop -- same as before.
        jsons.append(json.loads(text[start:end]))
        text = text[:start] + text[end:]
    if not jsons:
        # Mirror the original behaviour: the first (-1, -1) result made
        # json.loads("") raise, so "no JSON at all" must stay a ValueError.
        raise ValueError("No valid json found in text")
    return jsons
153+
144154
def _find_outermost_json(self, text):
145155
stack = []
146156
start_index = -1

src/ragas/metrics/_context_precision.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ async def _ascore(
138138
await json_loader.safe_load(item, self.llm, is_async=is_async)
139139
for item in responses
140140
]
141+
json_responses = t.cast(t.List[t.Dict], json_responses)
141142
score = self._calculate_average_precision(json_responses)
142143
return score
143144

src/ragas/metrics/_faithfulness.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ async def _ascore(
187187
is_async=is_async,
188188
)
189189

190+
assert isinstance(statements, dict), "Invalid JSON response"
190191
p = self._create_nli_prompt(row, statements.get("statements", []))
191192
nli_result = await self.llm.generate(p, callbacks=callbacks, is_async=is_async)
192193
json_output = await json_loader.safe_load(

0 commit comments

Comments
 (0)