
Commit e941823
fix: use aparse in all metrics (#831)
1 parent: b976369

6 files changed: 43 additions, 21 deletions


src/ragas/llms/output_parser.py

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,5 @@
 import json
+import logging
 import typing as t

 from langchain_core.exceptions import OutputParserException
@@ -8,6 +9,7 @@
 from ragas.llms import BaseRagasLLM
 from ragas.llms.prompt import Prompt, PromptValue

+logger = logging.getLogger(__name__)
 # The get_format_instructions function is a modified version from
 # langchain_core.output_parser.pydantic. The original version removed the "type" json schema
 # property that confused some older LLMs.
@@ -53,7 +55,7 @@ def get_json_format_instructions(pydantic_object: t.Type[TBaseModel]) -> str:

 class RagasoutputParser(PydanticOutputParser):
     async def aparse(  # type: ignore
-        self, result: str, prompt: PromptValue, llm: BaseRagasLLM, max_retries: int
+        self, result: str, prompt: PromptValue, llm: BaseRagasLLM, max_retries: int = 1
     ):
         try:
             output = super().parse(result)
@@ -66,5 +68,6 @@ async def aparse(  # type: ignore
                 result = output.generations[0][0].text
                 return await self.aparse(result, prompt, llm, max_retries - 1)
             else:
+                logger.warning("Failed to parse output. Returning None.")
                 return None
         return output
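
The hunks above show only the edges of the retry path. For orientation, here is a minimal sketch of how RagasoutputParser.aparse is assumed to behave end to end; the body of the OutputParserException handler and the fix_output_format_prompt used to re-ask the LLM are assumptions for illustration, not something visible in this commit.

import logging

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import PydanticOutputParser

from ragas.llms import BaseRagasLLM
from ragas.llms.prompt import Prompt, PromptValue

logger = logging.getLogger(__name__)

# Placeholder for the repair prompt ragas uses internally (not shown in this
# diff); assumed to turn the original prompt plus the bad completion into a
# new PromptValue asking the model to emit valid JSON.
fix_output_format_prompt: Prompt = ...


class RagasoutputParser(PydanticOutputParser):
    async def aparse(  # type: ignore
        self, result: str, prompt: PromptValue, llm: BaseRagasLLM, max_retries: int = 1
    ):
        try:
            # First attempt: parse the raw completion into the pydantic model.
            output = super().parse(result)
        except OutputParserException:
            if max_retries != 0:
                # Assumed repair step: ask the same LLM to rewrite its previous
                # completion so that it matches the expected JSON schema.
                p_value = fix_output_format_prompt.format(
                    prompt=prompt.to_string(), completion=result
                )
                output = await llm.generate(p_value)
                result = output.generations[0][0].text
                # Recurse with one fewer attempt left in the retry budget.
                return await self.aparse(result, prompt, llm, max_retries - 1)
            else:
                logger.warning("Failed to parse output. Returning None.")
                return None
        return output

Because max_retries now defaults to 1, call sites that pass only the text, the prompt, and the LLM (as _answer_relevance.py does below) still get a single repair attempt before aparse gives up and returns None, which the metrics then translate into np.nan.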

src/ragas/metrics/_answer_relevance.py

Lines changed: 3 additions & 1 deletion
@@ -157,11 +157,13 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
         )

         answers = [
-            _output_parser.parse(result.text) for result in result.generations[0]
+            await _output_parser.aparse(result.text, prompt, self.llm)
+            for result in result.generations[0]
         ]
         if any(answer is None for answer in answers):
             return np.nan

+        answers = [answer for answer in answers if answer is not None]
         return self._calculate_score(answers, row)

     def adapt(self, language: str, cache_dir: str | None = None) -> None:

src/ragas/metrics/_context_entities_recall.py

Lines changed: 8 additions & 5 deletions
@@ -135,6 +135,7 @@ class ContextEntityRecall(MetricWithLLM):
         default_factory=lambda: TEXT_ENTITY_EXTRACTION
     )
     batch_size: int = 15
+    max_retries: int = 1

     def _compute_score(
         self, ground_truth_entities: t.Sequence[str], context_entities: t.Sequence[str]
@@ -151,17 +152,19 @@ async def get_entities(
         is_async: bool,
     ) -> t.Optional[ContextEntitiesResponse]:
         assert self.llm is not None, "LLM is not initialized"
-
+        p_value = self.context_entity_recall_prompt.format(
+            text=text,
+        )
         result = await self.llm.generate(
-            prompt=self.context_entity_recall_prompt.format(
-                text=text,
-            ),
+            prompt=p_value,
             callbacks=callbacks,
             is_async=is_async,
         )

         result_text = result.generations[0][0].text
-        answer = _output_parser.parse(result_text)
+        answer = await _output_parser.aparse(
+            result_text, p_value, self.llm, self.max_retries
+        )
         if answer is None:
             return ContextEntitiesResponse(entities=[])
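
The same three-step shape repeats in each metric from here on: format the prompt once, keep the resulting PromptValue, generate, then hand the completion text and that PromptValue to aparse together with the metric's max_retries. A standalone sketch of the shared pattern follows; the helper name and its single-prompt signature are illustrative, and it assumes llm.generate's optional arguments (callbacks, is_async) can be left at their defaults.

import typing as t

from ragas.llms import BaseRagasLLM
from ragas.llms.output_parser import RagasoutputParser
from ragas.llms.prompt import Prompt


async def generate_and_parse(
    llm: BaseRagasLLM,
    prompt: Prompt,
    parser: RagasoutputParser,
    max_retries: int = 1,
    **prompt_vars: t.Any,
) -> t.Optional[t.Any]:
    # Illustrative helper, not part of ragas: it mirrors the pattern each
    # metric follows after this commit.
    p_value = prompt.format(**prompt_vars)  # keep the PromptValue around
    result = await llm.generate(prompt=p_value)
    result_text = result.generations[0][0].text
    # aparse re-prompts the LLM on malformed output and returns None once
    # the retry budget is exhausted.
    return await parser.aparse(result_text, p_value, llm, max_retries)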

src/ragas/metrics/_context_precision.py

Lines changed: 8 additions & 3 deletions
@@ -89,6 +89,7 @@ class ContextPrecision(MetricWithLLM):
     name: str = "context_precision"  # type: ignore
     evaluation_mode: EvaluationMode = EvaluationMode.qcg  # type: ignore
     context_precision_prompt: Prompt = field(default_factory=lambda: CONTEXT_PRECISION)
+    max_retries: int = 1

     def _get_row_attributes(self, row: t.Dict) -> t.Tuple[str, t.List[str], t.Any]:
         answer = "ground_truth"
@@ -138,20 +139,24 @@ async def _ascore(
         assert self.llm is not None, "LLM is not set"

         human_prompts = self._context_precision_prompt(row)
-        responses: t.List[str] = []
+        responses = []
         for hp in human_prompts:
             result = await self.llm.generate(
                 hp,
                 n=1,
                 callbacks=callbacks,
                 is_async=is_async,
             )
-            responses.append(result.generations[0][0].text)
+            responses.append([result.generations[0][0].text, hp])

-        items = [_output_parser.parse(item) for item in responses]
+        items = [
+            await _output_parser.aparse(item, hp, self.llm, self.max_retries)
+            for item, hp in responses
+        ]
         if any(item is None for item in items):
             return np.nan

+        items = [item for item in items if item is not None]
         answers = ContextPrecisionVerifications(__root__=items)
         score = self._calculate_average_precision(answers.__root__)
         return score
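
A side change worth noting: responses loses its t.List[str] annotation because each entry now pairs the completion with the prompt that produced it, which aparse needs in order to re-prompt the LLM on a parse failure. The same pairing could be written with tuples; an illustrative in-method fragment (type annotations assumed, other names as in the diff):

# Illustrative variant of the loop above: pair each completion with its
# prompt as a tuple so the parser can re-prompt against the right input.
responses: t.List[t.Tuple[str, PromptValue]] = []
for hp in human_prompts:
    result = await self.llm.generate(
        hp, n=1, callbacks=callbacks, is_async=is_async
    )
    responses.append((result.generations[0][0].text, hp))

items = [
    await _output_parser.aparse(text, hp, self.llm, self.max_retries)
    for text, hp in responses
]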

src/ragas/metrics/_context_recall.py

Lines changed: 6 additions & 3 deletions
@@ -122,6 +122,7 @@ class ContextRecall(MetricWithLLM):
     name: str = "context_recall"  # type: ignore
     evaluation_mode: EvaluationMode = EvaluationMode.qcg  # type: ignore
     context_recall_prompt: Prompt = field(default_factory=lambda: CONTEXT_RECALL_RA)
+    max_retries: int = 1

     def _create_context_recall_prompt(self, row: t.Dict) -> PromptValue:
         qstn, ctx, gt = row["question"], row["contexts"], row["ground_truth"]
@@ -142,15 +143,17 @@ def _compute_score(self, response: t.Any) -> float:

     async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
         assert self.llm is not None, "set LLM before use"
-
+        p_value = self._create_context_recall_prompt(row)
         result = await self.llm.generate(
-            self._create_context_recall_prompt(row),
+            p_value,
             callbacks=callbacks,
             is_async=is_async,
         )
         result_text = result.generations[0][0].text

-        answers = _output_parser.parse(result_text)
+        answers = await _output_parser.aparse(
+            result_text, p_value, self.llm, self.max_retries
+        )
         if answers is None:
             return np.nan

src/ragas/metrics/_faithfulness.py

Lines changed: 14 additions & 8 deletions
@@ -6,7 +6,6 @@
 from dataclasses import dataclass, field

 import numpy as np
-from langchain_core.output_parsers import PydanticOutputParser
 from langchain_core.pydantic_v1 import BaseModel, Field

 from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
@@ -85,7 +84,7 @@ def dicts(self) -> t.List[t.Dict]:
 _faithfulness_output_instructions = get_json_format_instructions(
     StatementFaithfulnessAnswers
 )
-_faithfulness_output_parser = PydanticOutputParser(
+_faithfulness_output_parser = RagasoutputParser(
     pydantic_object=StatementFaithfulnessAnswers
 )

@@ -157,6 +156,7 @@ class Faithfulness(MetricWithLLM):
     nli_statements_message: Prompt = field(
         default_factory=lambda: NLI_STATEMENTS_MESSAGE
     )
+    max_retries: int = 1

     def _create_answer_prompt(self, row: t.Dict) -> PromptValue:
         question, answer = row["question"], row["answer"]
@@ -200,20 +200,26 @@ async def _ascore(
         returns the NLI score for each (q, c, a) pair
         """
         assert self.llm is not None, "LLM is not set"
-        p = self._create_answer_prompt(row)
+        p_value = self._create_answer_prompt(row)
         answer_result = await self.llm.generate(
-            p, callbacks=callbacks, is_async=is_async
+            p_value, callbacks=callbacks, is_async=is_async
         )
         answer_result_text = answer_result.generations[0][0].text
-        statements = _statements_output_parser.parse(answer_result_text)
+        statements = await _statements_output_parser.aparse(
+            answer_result_text, p_value, self.llm, self.max_retries
+        )
         if statements is None:
             return np.nan

-        p = self._create_nli_prompt(row, statements.__root__)
-        nli_result = await self.llm.generate(p, callbacks=callbacks, is_async=is_async)
+        p_value = self._create_nli_prompt(row, statements.__root__)
+        nli_result = await self.llm.generate(
+            p_value, callbacks=callbacks, is_async=is_async
+        )
         nli_result_text = nli_result.generations[0][0].text

-        faithfulness = _faithfulness_output_parser.parse(nli_result_text)
+        faithfulness = await _faithfulness_output_parser.aparse(
+            nli_result_text, p_value, self.llm, self.max_retries
+        )
         if faithfulness is None:
             return np.nan
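
Every metric touched by this commit now carries a max_retries field defaulting to 1, so the repair budget can be raised per metric instance. A usage sketch; the import paths follow the file paths above, and whether the classes are also re-exported from ragas.metrics is not shown in this diff.

from ragas.metrics._context_recall import ContextRecall
from ragas.metrics._faithfulness import Faithfulness

# Allow up to three repair attempts per parse instead of the default one.
faithfulness = Faithfulness(max_retries=3)
context_recall = ContextRecall(max_retries=3)

# Instances left at the default still get a single repair attempt before
# aparse returns None and the metric scores the sample as np.nan.
assert Faithfulness().max_retries == 1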
