 
 import numpy as np
 from datasets import Dataset
+from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
+from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 
 from ragas.metrics._answer_similarity import AnswerSimilarity
-from ragas.metrics._faithfulness import Faithfulness
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
+from ragas.utils import load_as_json
 
-if t.TYPE_CHECKING:
-    from langchain.callbacks.manager import CallbackManager
+CORRECTNESS_PROMPT = HumanMessagePromptTemplate.from_template(
+    """
+Extract the following statements from the given question, answer, and ground truth.
+
+Question: What powers the sun and what is its primary function?
+Answer: The sun is powered by nuclear fission, similar to nuclear reactors on Earth, and its primary function is to provide light to the solar system.
+Ground truth: The sun is actually powered by nuclear fusion, not fission. In its core, hydrogen atoms fuse to form helium, releasing a tremendous amount of energy. This energy is what lights up the sun and provides heat and light, essential for life on Earth. The sun's light also plays a critical role in Earth's climate system and helps to drive the weather and ocean currents.
+Extracted statements:
+[
+  {{
+    "statements that are present in both the answer and the ground truth": ["The sun's primary function is to provide light"],
+    "statements present in the answer but not found in the ground truth": ["The sun is powered by nuclear fission", "similar to nuclear reactors on Earth"],
+    "relevant statements found in the ground truth but omitted in the answer": ["The sun is powered by nuclear fusion, not fission", "In its core, hydrogen atoms fuse to form helium, releasing a tremendous amount of energy", "This energy provides heat and light, essential for life on Earth", "The sun's light plays a critical role in Earth's climate system", "The sun helps to drive the weather and ocean currents"]
+  }}
+]
+
+Question: What is the boiling point of water?
+Answer: The boiling point of water is 100 degrees Celsius at sea level.
+Ground truth: The boiling point of water is 100 degrees Celsius (212 degrees Fahrenheit) at sea level, but it can change with altitude.
+Extracted statements:
+[
+  {{
+    "statements that are present in both the answer and the ground truth": ["The boiling point of water is 100 degrees Celsius at sea level"],
+    "statements present in the answer but not found in the ground truth": [],
+    "relevant statements found in the ground truth but omitted in the answer": ["The boiling point can change with altitude", "The boiling point of water is 212 degrees Fahrenheit at sea level"]
+  }}
+]
+
+Question: {question}
+Answer: {answer}
+Ground truth: {ground_truth}
+Extracted statements:"""  # noqa: E501
+)
 
 
 @dataclass
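The new `CORRECTNESS_PROMPT` above is a few-shot template: the model is expected to reply with a JSON list containing an object whose three keys bucket statements into overlaps (TP), hallucinations (FP), and omissions (FN). A minimal sketch of what a well-formed completion looks like and how it reduces to counts, using `json.loads` as a stand-in for `ragas.utils.load_as_json` (whose implementation isn't shown in this diff):

```python
import json

# Hypothetical completion text, shaped like the prompt's few-shot examples;
# in _score_batch the real text comes from self.llm.generate().
completion = """
[
  {
    "statements that are present in both the answer and the ground truth": ["The boiling point of water is 100 degrees Celsius at sea level"],
    "statements present in the answer but not found in the ground truth": [],
    "relevant statements found in the ground truth but omitted in the answer": ["The boiling point can change with altitude"]
  }
]
"""

parsed = json.loads(completion)  # stand-in for ragas.utils.load_as_json
item = parsed[0]
tp = len(item["statements that are present in both the answer and the ground truth"])
fp = len(item["statements present in the answer but not found in the ground truth"])
fn = len(item["relevant statements found in the ground truth but omitted in the answer"])
print(tp, fp, fn)  # 1 0 1
```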
@@ -39,34 +73,61 @@ class AnswerCorrectness(MetricWithLLM): |
     name: str = "answer_correctness"
     evaluation_mode: EvaluationMode = EvaluationMode.qga
     batch_size: int = 15
-    weights: list[float] = field(default_factory=lambda: [0.5, 0.5])
+    weights: list[float] = field(default_factory=lambda: [0.75, 0.25])
     answer_similarity: AnswerSimilarity | None = None
-    faithfulness: Faithfulness | None = None
 
     def __post_init__(self: t.Self):
         if self.answer_similarity is None:
             self.answer_similarity = AnswerSimilarity(
                 llm=self.llm, batch_size=self.batch_size
             )
-        if self.faithfulness is None:
-            self.faithfulness = Faithfulness(llm=self.llm, batch_size=self.batch_size)
     def _score_batch(
         self: t.Self,
         dataset: Dataset,
         callbacks: t.Optional[CallbackManager] = None,
         callback_group_name: str = "batch",
     ) -> list[float]:
-        if "contexts" in dataset.column_names:
-            ds_faithfulness = dataset.remove_columns(["contexts"])
-        else:
-            ds_faithfulness = dataset
+        question, answer, ground_truths = (
+            dataset["question"],
+            dataset["answer"],
+            dataset["ground_truths"],
+        )
+        prompts = []
+
+        with trace_as_chain_group(
+            callback_group_name, callback_manager=callbacks
+        ) as batch_group:
+            for q, a, g in zip(question, answer, ground_truths):
+                human_prompt = CORRECTNESS_PROMPT.format(
+                    question=q, ground_truth=g[0], answer=a
+                )
+                prompts.append(ChatPromptTemplate.from_messages([human_prompt]))
+
+            result = self.llm.generate(prompts, callbacks=batch_group)
+            outputs = result.generations
+            key_map = {
+                "TP": "statements that are present in both the answer and the ground truth",
+                "FP": "statements present in the answer but not found in the ground truth",
+                "FN": "relevant statements found in the ground truth but omitted in the answer",  # noqa: E501
+            }
+
+            f1_score = []
+            for prediction in outputs:
+                prediction = load_as_json(prediction[0].text)
+                prediction = [
+                    item.get(key_map[k], np.nan)
+                    for item in prediction
+                    for k in key_map.keys()
+                ]
+                tp, fp, fn = [
+                    len(item) if isinstance(item, list) else np.nan for item in prediction
+                ]
+                score = tp / (tp + 0.5 * (fp + fn))
+                f1_score.append(score)
 
-        ds_faithfulness = ds_faithfulness.rename_columns({"ground_truths": "contexts"})
-        faith_scores = self.faithfulness._score_batch(ds_faithfulness)  # type: ignore
         similarity_scores = self.answer_similarity._score_batch(dataset)  # type: ignore
-
-        scores_stacked = np.vstack([faith_scores, similarity_scores])
+        scores_stacked = np.vstack([f1_score, similarity_scores])
         scores = np.average(
             scores_stacked,
             axis=0,
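With faithfulness removed, the final answer-correctness score is now a weighted average of this statement-level F1 and the embedding-based answer similarity, and the new default weights favor factual overlap. A small numeric sketch of the combination; the TP/FP/FN counts and the 0.9 similarity value are made up for illustration:

```python
import numpy as np

# Suppose the model extracted 1 overlapping statement (TP),
# 0 hallucinated statements (FP), and 2 omitted ones (FN).
tp, fp, fn = 1, 0, 2
f1 = tp / (tp + 0.5 * (fp + fn))  # classic F1: 1 / (1 + 0.5 * 2) = 0.5

similarity = 0.9          # hypothetical AnswerSimilarity score
weights = [0.75, 0.25]    # the new defaults: F1 dominates similarity

scores_stacked = np.vstack([[f1], [similarity]])
score = np.average(scores_stacked, axis=0, weights=weights)
print(score)  # [0.6] -> 0.75 * 0.5 + 0.25 * 0.9
```

Note that `tp / (tp + 0.5 * (fp + fn))` divides by zero when the model extracts no statements at all, so a completion with three empty lists would raise rather than score 0.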