Commit 4d01af2

feat: improve quality of answer correctness (#339)
1 parent ba3b109 commit 4d01af2

File tree: 2 files changed, +152 -59 lines

experiments/assesments/metrics_assesments.ipynb

Lines changed: 76 additions & 44 deletions
@@ -98,46 +98,26 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 17,
    "id": "b3139189",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "{\n",
-      "  \"role\": \"assistant\",\n",
-      "  \"content\": \"How can I assist you today?\"\n",
-      "}\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import os\n",
-    "import openai\n",
-    "\n",
-    "openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
-    "\n",
-    "completion = openai.ChatCompletion.create(\n",
-    "    model=\"gpt-3.5-turbo\",\n",
-    "    messages=[\n",
-    "        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
-    "    ],\n",
-    ")\n",
+    "from openai import OpenAI\n",
     "\n",
-    "print(completion.choices[0].message)"
+    "client = OpenAI()\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 18,
    "id": "4bce4c53",
    "metadata": {},
    "outputs": [],
    "source": [
-    "def llm2(prompt, **kwargs):\n",
-    "    response = openai.ChatCompletion.create(\n",
+    "def llm(prompt, **kwargs):\n",
+    "    response = client.chat.completions.create(\n",
     "        model=kwargs.get(\"model\", \"gpt-3.5-turbo\"),\n",
     "        messages=[{\"role\": \"system\", \"content\": prompt}],\n",
     "        temperature=kwargs.get(\"temperature\", 0),\n",
@@ -147,27 +127,12 @@
     "        max_tokens=kwargs.get(\"max_tokens\", 500),\n",
     "        n=kwargs.get(\"n\", 1),\n",
     "    )\n",
-    "    return response\n",
-    "\n",
-    "\n",
-    "def llm(prompt, **kwargs):\n",
-    "    response = openai.Completion.create(\n",
-    "        model=kwargs.get(\"model\", \"text-davinci-003\"),\n",
-    "        prompt=prompt,\n",
-    "        temperature=kwargs.get(\"temperature\", 0),\n",
-    "        top_p=kwargs.get(\"top_p\", 1),\n",
-    "        frequency_penalty=kwargs.get(\"frequency_penalty\", 0.0),\n",
-    "        presence_penalty=kwargs.get(\"presence_penalty\", 0.0),\n",
-    "        max_tokens=kwargs.get(\"max_tokens\", 500),\n",
-    "        logprobs=kwargs.get(\"logprobs\", 0),\n",
-    "        n=kwargs.get(\"n\", 1),\n",
-    "    )\n",
     "    return response"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 19,
    "id": "4d9b4e31",
    "metadata": {},
    "outputs": [],
@@ -2341,11 +2306,78 @@
     "results.to_dict()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "387bb6ea",
+   "metadata": {},
+   "source": [
+    "## Answer correctness"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "id": "47465fd1",
    "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/opt/anaconda3/envs/ragas/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "from ragas.metrics import answer_correctness"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "76b13fc8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = {\"question\":\"Where is France and what's it capital?\", \"answer\":\"Asia\",\n",
+    "        'ground_truths':[\"France is in Europe and it's capital is Paris\"]}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "817f4150",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "faith [0.0]\n",
+      "sim [True]\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "0.5"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "answer_correctness.score_single(data)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "50b595cf",
+   "metadata": {},
    "outputs": [],
    "source": []
   }
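For reference, the notebook cells above move the helper from the legacy openai.ChatCompletion interface to the openai>=1.0 client. A minimal, self-contained sketch of that pattern (a sketch only, not taken verbatim from the commit; it assumes OPENAI_API_KEY is set in the environment and, unlike the notebook helper, returns the completion text instead of the raw response object):

import os
from openai import OpenAI

# The v1 client reads OPENAI_API_KEY from the environment by default;
# passing it explicitly just makes the assumption visible.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def llm(prompt, **kwargs):
    # Single system-message chat completion, mirroring the notebook helper.
    response = client.chat.completions.create(
        model=kwargs.get("model", "gpt-3.5-turbo"),
        messages=[{"role": "system", "content": prompt}],
        temperature=kwargs.get("temperature", 0),
        max_tokens=kwargs.get("max_tokens", 500),
        n=kwargs.get("n", 1),
    )
    return response.choices[0].message.content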

src/ragas/metrics/_answer_correctness.py

Lines changed: 76 additions & 15 deletions
@@ -5,13 +5,47 @@
 
 import numpy as np
 from datasets import Dataset
+from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
+from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 
 from ragas.metrics._answer_similarity import AnswerSimilarity
-from ragas.metrics._faithfulness import Faithfulness
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
+from ragas.utils import load_as_json
 
-if t.TYPE_CHECKING:
-    from langchain.callbacks.manager import CallbackManager
+CORRECTNESS_PROMPT = HumanMessagePromptTemplate.from_template(
+    """
+Extract following from given question and ground truth
+
+Question:What powers the sun and what is its primary function?
+Answer: The sun is powered by nuclear fission, similar to nuclear reactors on Earth, and its primary function is to provide light to the solar system.
+Ground truth: The sun is actually powered by nuclear fusion, not fission. In its core, hydrogen atoms fuse to form helium, releasing a tremendous amount of energy. This energy is what lights up the sun and provides heat and light, essential for life on Earth. The sun's light also plays a critical role in Earth's climate system and helps to drive the weather and ocean currents.
+Extracted statements:
+[
+{{
+"statements that are present in both the answer and the ground truth": ["The sun's primary function is to provide light"],
+"statements present in the answer but not found in the ground truth": ["The sun is powered by nuclear fission", "similar to nuclear reactors on Earth"],
+"relevant statements found in the ground truth but omitted in the answer": ["The sun is powered by nuclear fusion, not fission", "In its core, hydrogen atoms fuse to form helium, releasing a tremendous amount of energy", "This energy provides heat and light, essential for life on Earth", "The sun's light plays a critical role in Earth's climate system", "The sun helps to drive the weather and ocean currents"]
+}}
+]
+
+Question: What is the boiling point of water?
+Answer: The boiling point of water is 100 degrees Celsius at sea level.
+Ground truth: The boiling point of water is 100 degrees Celsius (212 degrees Fahrenheit) at sea level, but it can change with altitude.
+Extracted statements:
+[
+{{
+"statements that are present in both the answer and the ground truth": ["The boiling point of water is 100 degrees Celsius at sea level"],
+"statements present in the answer but not found in the ground truth": [],
+"relevant statements found in the ground truth but omitted in the answer": ["The boiling point can change with altitude", "The boiling point of water is 212 degrees Fahrenheit at sea level"]
+}}
+]
+
+
+Question:{question}
+Answer: {answer}
+Ground truth: {ground_truth}
+Extracted statements:"""  # noqa: E501
+)
 
 
 @dataclass
@@ -39,34 +73,61 @@ class AnswerCorrectness(MetricWithLLM):
     name: str = "answer_correctness"
     evaluation_mode: EvaluationMode = EvaluationMode.qga
     batch_size: int = 15
-    weights: list[float] = field(default_factory=lambda: [0.5, 0.5])
+    weights: list[float] = field(default_factory=lambda: [0.75, 0.25])
     answer_similarity: AnswerSimilarity | None = None
-    faithfulness: Faithfulness | None = None
 
     def __post_init__(self: t.Self):
         if self.answer_similarity is None:
             self.answer_similarity = AnswerSimilarity(
                 llm=self.llm, batch_size=self.batch_size
             )
-        if self.faithfulness is None:
-            self.faithfulness = Faithfulness(llm=self.llm, batch_size=self.batch_size)
 
     def _score_batch(
         self: t.Self,
         dataset: Dataset,
         callbacks: t.Optional[CallbackManager] = None,
         callback_group_name: str = "batch",
     ) -> list[float]:
-        if "contexts" in dataset.column_names:
-            ds_faithfulness = dataset.remove_columns(["contexts"])
-        else:
-            ds_faithfulness = dataset
+        question, answer, ground_truths = (
+            dataset["question"],
+            dataset["answer"],
+            dataset["ground_truths"],
+        )
+        prompts = []
+
+        with trace_as_chain_group(
+            callback_group_name, callback_manager=callbacks
+        ) as batch_group:
+            for q, a, g in zip(question, answer, ground_truths):
+                human_prompt = CORRECTNESS_PROMPT.format(
+                    question=q, ground_truth=g[0], answer=a
+                )
+                prompts.append(ChatPromptTemplate.from_messages([human_prompt]))
+
+            result = self.llm.generate(prompts, callbacks=batch_group)
+            outputs = result.generations
+            key_map = {
+                "TP": "statements that are present in both the answer and the ground truth",
+                "FP": "statements present in the answer but not found in the ground truth",
+                "FN": "relevant statements found in the ground truth but omitted in the answer",  # noqa: E501
+            }
+
+            f1_score = []
+            for prediction in outputs:
+                prediction = load_as_json(prediction[0].text)
+                prediction = [
+                    item.get(key_map[k], np.nan)
+                    for item in prediction
+                    for k in key_map.keys()
+                ]
+                tp, fp, fn = [
+                    len(item) if isinstance(item, list) else np.nan for item in prediction
+                ]
+                score = tp / (tp + 0.5 * (fp + fn))
+                f1_score.append(score)
 
-        ds_faithfulness = ds_faithfulness.rename_columns({"ground_truths": "contexts"})
-        faith_scores = self.faithfulness._score_batch(ds_faithfulness)  # type: ignore
         similarity_scores = self.answer_similarity._score_batch(dataset)  # type: ignore
-
-        scores_stacked = np.vstack([faith_scores, similarity_scores])
+        scores_stacked = np.vstack([f1_score, similarity_scores])
         scores = np.average(
             scores_stacked,
             axis=0,
0 commit comments

Comments
 (0)