@@ -6,7 +6,6 @@
 from dataclasses import dataclass, field
 
 import numpy as np
-from langchain_core.output_parsers import PydanticOutputParser
 from langchain_core.pydantic_v1 import BaseModel, Field
 
 from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
@@ -85,7 +84,7 @@ def dicts(self) -> t.List[t.Dict]:
 _faithfulness_output_instructions = get_json_format_instructions(
     StatementFaithfulnessAnswers
 )
-_faithfulness_output_parser = PydanticOutputParser(
+_faithfulness_output_parser = RagasoutputParser(
     pydantic_object=StatementFaithfulnessAnswers
 )
 
@@ -157,6 +156,7 @@ class Faithfulness(MetricWithLLM):
     nli_statements_message: Prompt = field(
         default_factory=lambda: NLI_STATEMENTS_MESSAGE
     )
+    max_retries: int = 1
 
     def _create_answer_prompt(self, row: t.Dict) -> PromptValue:
         question, answer = row["question"], row["answer"]
@@ -200,20 +200,26 @@ async def _ascore(
         returns the NLI score for each (q, c, a) pair
         """
         assert self.llm is not None, "LLM is not set"
-        p = self._create_answer_prompt(row)
+        p_value = self._create_answer_prompt(row)
         answer_result = await self.llm.generate(
-            p, callbacks=callbacks, is_async=is_async
+            p_value, callbacks=callbacks, is_async=is_async
         )
         answer_result_text = answer_result.generations[0][0].text
-        statements = _statements_output_parser.parse(answer_result_text)
+        statements = await _statements_output_parser.aparse(
+            answer_result_text, p_value, self.llm, self.max_retries
+        )
         if statements is None:
             return np.nan
 
-        p = self._create_nli_prompt(row, statements.__root__)
-        nli_result = await self.llm.generate(p, callbacks=callbacks, is_async=is_async)
+        p_value = self._create_nli_prompt(row, statements.__root__)
+        nli_result = await self.llm.generate(
+            p_value, callbacks=callbacks, is_async=is_async
+        )
         nli_result_text = nli_result.generations[0][0].text
 
-        faithfulness = _faithfulness_output_parser.parse(nli_result_text)
+        faithfulness = await _faithfulness_output_parser.aparse(
+            nli_result_text, p_value, self.llm, self.max_retries
+        )
         if faithfulness is None:
             return np.nan
 
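For context, a minimal, self-contained sketch of the retry-on-parse-failure pattern this change switches to: parse the LLM output, and on a validation failure re-prompt the model with the format instructions up to max_retries times, returning None (mapped to np.nan by the caller) if every attempt fails. This is an assumption-based approximation, not the actual RagasoutputParser.aparse implementation; parse_with_retries and ask_llm are hypothetical names introduced here only for illustration.

# Sketch only: approximates the retry behaviour the diff relies on,
# not the ragas implementation.
import typing as t

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import PydanticOutputParser


async def parse_with_retries(
    parser: PydanticOutputParser,
    text: str,
    ask_llm: t.Callable[[str], t.Awaitable[str]],  # hypothetical async LLM call
    max_retries: int = 1,
) -> t.Optional[t.Any]:
    for _ in range(max_retries + 1):
        try:
            return parser.parse(text)
        except OutputParserException:
            # Ask the model to repair its own output against the schema,
            # then try parsing the repaired text on the next iteration.
            text = await ask_llm(
                "Reformat the output below to follow these instructions:\n"
                f"{parser.get_format_instructions()}\n\nOutput:\n{text}"
            )
    return None  # callers such as _ascore map None to np.nan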