
Commit 6cf7773

answer-correctness : fix edge cases (#970)
fix: #959
1 parent e5e4543 commit 6cf7773

File tree

1 file changed (+31, -21 lines)

src/ragas/metrics/_answer_correctness.py

Lines changed: 31 additions & 21 deletions
```diff
@@ -66,7 +66,7 @@ class AnswerCorrectnessClassification(BaseModel):
                 "The sun's light plays a critical role in Earth's climate system.",
                 "Sunlight helps to drive the weather and ocean currents.",
             ],
-            "extracted_statements": AnswerCorrectnessClassification.parse_obj(
+            "classification": AnswerCorrectnessClassification.parse_obj(
                 {
                     "TP": [
                         {
```
```diff
@@ -114,7 +114,7 @@ class AnswerCorrectnessClassification(BaseModel):
                 "The boiling point of water is 100 degrees Celsius (212 degrees Fahrenheit) at sea level.",
                 "The boiling point of water can change with altitude.",
             ],
-            "extracted_statements": AnswerCorrectnessClassification.parse_obj(
+            "classification": AnswerCorrectnessClassification.parse_obj(
                 {
                     "TP": [
                         {
```
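Both example hunks wrap the expected output in `AnswerCorrectnessClassification.parse_obj(...)`, a pydantic model that buckets statements into TP/FP/FN. Below is a minimal, self-contained sketch of what that parse looks like; the TP/FP/FN keys come straight from the diff, while the `statement`/`reason` sub-fields are an assumption for illustration, since the hunks are truncated after the opening brace:

```python
import typing as t

from pydantic import BaseModel


class StatementsWithReason(BaseModel):
    # Assumed shape of each TP/FP/FN entry; only the bucket names are
    # visible in the truncated hunks above.
    statement: str
    reason: str


class AnswerCorrectnessClassification(BaseModel):
    TP: t.List[StatementsWithReason]
    FP: t.List[StatementsWithReason]
    FN: t.List[StatementsWithReason]


classification = AnswerCorrectnessClassification.parse_obj(
    {
        "TP": [
            {
                "statement": "Sunlight helps to drive the weather and ocean currents.",
                "reason": "Also stated in the ground truth.",
            }
        ],
        "FP": [],
        "FN": [],
    }
)
print(classification.TP[0].statement)
```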
```diff
@@ -134,7 +134,7 @@ class AnswerCorrectnessClassification(BaseModel):
         },
     ],
     input_keys=["question", "answer", "ground_truth"],
-    output_key="extracted_statements",
+    output_key="classification",
     output_type="json",
 )

```
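The third hunk keeps `output_key` in sync with the renamed example field: the few-shot example's expected output has to live under the same key that `output_key` names, or the example's output never makes it into the rendered prompt. A minimal sketch of that constraint, with a hypothetical `render_example` helper standing in for the real prompt formatting (this is not ragas' API):

```python
import json
from typing import Any, Dict, List


def render_example(
    example: Dict[str, Any], input_keys: List[str], output_key: str
) -> str:
    """Hypothetical few-shot renderer: inputs first, then the expected
    output looked up under `output_key`."""
    lines = [f"{key}: {json.dumps(example[key])}" for key in input_keys]
    # If the example still stored its output under the stale
    # "extracted_statements" key, this lookup would raise KeyError --
    # which is the mismatch the rename removes.
    lines.append(f"{output_key}: {json.dumps(example[output_key])}")
    return "\n".join(lines)


example = {
    "question": "What is the boiling point of water?",
    "answer": ["The boiling point of water is 100 degrees Celsius at sea level."],
    "ground_truth": ["The boiling point of water can change with altitude."],
    "classification": {"TP": [], "FP": [], "FN": []},
}
print(render_example(example, ["question", "answer", "ground_truth"], "classification"))
```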

```diff
@@ -231,26 +231,36 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
                 statements[item].dicts() if statements[item] is not None else []
             )

-        if any(val is [] for val in statements.values()):
-            return np.nan
-
-        p_value = self.correctness_prompt.format(
-            question=question,
-            ground_truth=statements["ground_truth"],
-            answer=statements["answer"],
-        )
-        is_statement_present = await self.llm.generate(
-            p_value, callbacks=callbacks, is_async=is_async
-        )
-        result_text = is_statement_present.generations[0][0].text
+        if not all([val == [] for val in statements.values()]):
+            ground_truth = [
+                statement
+                for item in statements["ground_truth"]
+                for statement in item["simpler_statements"]
+            ]
+            answer = [
+                statement
+                for item in statements["answer"]
+                for statement in item["simpler_statements"]
+            ]
+            p_value = self.correctness_prompt.format(
+                question=question,
+                ground_truth=ground_truth,
+                answer=answer,
+            )
+            is_statement_present = await self.llm.generate(
+                p_value, callbacks=callbacks, is_async=is_async
+            )
+            result_text = is_statement_present.generations[0][0].text

-        answers = await _output_parser.aparse(
-            result_text, p_value, self.llm, self.max_retries
-        )
-        if answers is None:
-            return np.nan
+            answers = await _output_parser.aparse(
+                result_text, p_value, self.llm, self.max_retries
+            )
+            if answers is None:
+                return np.nan

-        f1_score = self._compute_statement_presence(answers)
+            f1_score = self._compute_statement_presence(answers)
+        else:
+            f1_score = 1.0

         if self.weights[1] == 0:
             similarity_score = 0.0
```
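
The behavioral fix is in the emptiness guard. The old check `any(val is [] for val in statements.values())` could never fire: `is []` compares identity against a freshly built list, so it is always False and empty statement lists slipped straight into the classification prompt. The new branch compares with `==`, runs the LLM classification only when at least one side produced statements, and falls back to `f1_score = 1.0` when both the answer and the ground truth decomposed into nothing. A standalone sketch of that logic (not the metric code itself):

```python
# Identity vs. equality for an empty list: `is []` is always False because
# the literal builds a brand-new object (recent CPython even emits a
# SyntaxWarning for `is` against a literal), while `== []` checks contents.
val: list = []
print(val is [])  # False
print(val == [])  # True

# Fallback behaviour introduced by the patch: skip the LLM classification
# entirely when neither side yielded any simplified statements.
statements = {"answer": [], "ground_truth": []}
if not all(v == [] for v in statements.values()):
    f1_score = None  # placeholder: format prompt, call LLM, parse TP/FP/FN, compute F1
else:
    f1_score = 1.0
print(f1_score)  # 1.0 when both lists are empty
```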
