Skip to content

Commit 0d1223c

Browse files
authored
fix: NaN in metrics/test set gen (#786)
1 parent 66d236c commit 0d1223c

File tree

1 file changed

+16
-20
lines changed

1 file changed

+16
-20
lines changed

src/ragas/metrics/_faithfulness.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555

5656
NLI_STATEMENTS_MESSAGE = Prompt(
5757
name="nli_statements",
58-
instruction="Natural language inference. Use only 'Yes' (1), 'No' (0) and 'Null' (-1) as verdict.",
58+
instruction="Natural language inference. Use only 'Yes' (1), 'No' (0)",
5959
examples=[
6060
{
6161
"context": """John is a student at XYZ University. He is pursuing a degree in Computer Science. He is enrolled in several courses this semester, including Data Structures, Algorithms, and Database Management. John is a diligent student and spends a significant amount of time studying and completing assignments. He often stays late in the library to work on his projects.""",
@@ -97,15 +97,6 @@
9797
"verdict": "0",
9898
},
9999
},
100-
{
101-
"context": """Albert Einstein was a German-born theoretical physicist who is widely held to be one of the greatest and most influential scientists of all time.""",
102-
"statements": """statement_1: Nil""",
103-
"answer": {
104-
"statement_1": "Nil",
105-
"reason": "The statement is invalid",
106-
"verdict": "-1",
107-
},
108-
},
109100
],
110101
input_keys=["context", "statements"],
111102
output_key="answer",
@@ -139,7 +130,6 @@ def _create_nli_prompt(self, row: t.Dict, statements: t.Any) -> PromptValue:
139130
contexts = row["contexts"]
140131
# check if the statements are support in the contexts
141132
contexts_str: str = "\n".join(contexts)
142-
statements = statements if statements != [] else ["Nil"]
143133
statements_str: str = "\n".join(
144134
[f"statement_{i+1}: {st}" for i, st in enumerate(statements)]
145135
)
@@ -150,7 +140,7 @@ def _create_nli_prompt(self, row: t.Dict, statements: t.Any) -> PromptValue:
150140

151141
def _compute_score(self, output: t.Any):
152142
# check the verdicts and compute the score
153-
verdict_score_map = {"1": 1, "0": 0, "-1": np.nan}
143+
verdict_score_map = {"1": 1, "0": 0}
154144
output = output if isinstance(output, list) else [output]
155145
faithful_statements = sum(
156146
verdict_score_map.get(
@@ -190,14 +180,20 @@ async def _ascore(
190180
)
191181

192182
statements = statements if isinstance(statements, dict) else {}
193-
p = self._create_nli_prompt(row, statements.get("statements", []))
194-
nli_result = await self.llm.generate(p, callbacks=callbacks, is_async=is_async)
195-
json_output = await json_loader.safe_load(
196-
text=nli_result.generations[0][0].text,
197-
llm=self.llm,
198-
callbacks=callbacks,
199-
is_async=is_async,
200-
)
183+
statements = statements.get("statements", [])
184+
if statements:
185+
p = self._create_nli_prompt(row, statements)
186+
nli_result = await self.llm.generate(
187+
p, callbacks=callbacks, is_async=is_async
188+
)
189+
json_output = await json_loader.safe_load(
190+
text=nli_result.generations[0][0].text,
191+
llm=self.llm,
192+
callbacks=callbacks,
193+
is_async=is_async,
194+
)
195+
else:
196+
json_output = [{}]
201197
return self._compute_score(json_output)
202198

203199
def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:

0 commit comments

Comments (0)