Commit 0430e8f

fix: improved answer relevancy (#346)
1 parent af55f18 commit 0430e8f

6 files changed: +80 additions, −35 deletions


src/ragas/llms/langchain.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ def isOpenAI(llm: BaseLLM | BaseChatModel) -> bool:
 def isBedrock(llm: BaseLLM | BaseChatModel) -> bool:
     return isinstance(llm, Bedrock) or isinstance(llm, BedrockChat)
 
+
 def isAmazonAPIGateway(llm: BaseLLM | BaseChatModel) -> bool:
     return isinstance(llm, AmazonAPIGateway)

src/ragas/llms/llamaindex.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 if t.TYPE_CHECKING:
     from langchain.callbacks.base import Callbacks
     from langchain.prompts import ChatPromptTemplate
+
 try:
     from llama_index.llms.base import LLM as LiLLM
 except ImportError:

src/ragas/metrics/_answer_relevance.py

Lines changed: 54 additions & 18 deletions
@@ -12,6 +12,7 @@
 from ragas.embeddings.base import embedding_factory
 from ragas.exceptions import OpenAIKeyNotFound
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
+from ragas.utils import load_as_json
 
 if t.TYPE_CHECKING:
     from langchain.callbacks.manager import CallbackManager
@@ -21,13 +22,46 @@
 
 QUESTION_GEN = HumanMessagePromptTemplate.from_template(
     """
-Generate question for the given answer.
-Answer:\nThe PSLV-C56 mission is scheduled to be launched on Sunday, 30 July 2023 at 06:30 IST / 01:00 UTC. It will be launched from the Satish Dhawan Space Centre, Sriharikota, Andhra Pradesh, India
-Question: When is the scheduled launch date and time for the PSLV-C56 mission, and where will it be launched from?
+Generate a question for the given answer and Identify if answer is noncommittal
 
-Answer:{answer}
-Question:
-""" # noqa: E501
+Answer:
+Albert Einstein was born in Germany.
+Context:
+Albert Einstein was a German-born theoretical physicist who is widely held to be one of the greatest and most influential scientists of all time
+Output:
+{{"question":"Where was Albert Einstein born?","noncommittal":false}}
+
+
+Answer:
+It can change its skin color based on the temperature of its environment.
+Context:
+A recent scientific study has discovered a new species of frog in the Amazon rainforest that has the unique ability to change its skin color based on the temperature of its environment.
+Output:
+{{"question":"What unique ability does the newly discovered species of frog have?","noncommittal":false}}
+
+
+Answer:
+Everest
+Context:
+The tallest mountain on Earth, measured from sea level, is a renowned peak located in the Himalayas.
+Output:
+{{"question":"What is the tallest mountain on Earth?","noncommittal":false}}
+
+
+Answer:
+I don't know about the groundbreaking feature of the smartphone invented in 2023 as am unware of information beyong 2022.
+Context:
+In 2023, a groundbreaking invention was announced: a smartphone with a battery life of one month, revolutionizing the way people use mobile technology.
+Output:
+{{"question":"What was the groundbreaking feature of the smartphone invented in 2023?", "noncommittal":true}}
+
+
+
+Answer:
+{answer}
+Context:
+{context}
+Output:""" # noqa: E501
 )
 
 
@@ -53,7 +87,7 @@ class AnswerRelevancy(MetricWithLLM):
     """
 
     name: str = "answer_relevancy"
-    evaluation_mode: EvaluationMode = EvaluationMode.qa
+    evaluation_mode: EvaluationMode = EvaluationMode.qac
     batch_size: int = 15
     strictness: int = 3
     embeddings: RagasEmbeddings = field(default_factory=embedding_factory)
@@ -71,29 +105,31 @@ def _score_batch(
         callbacks: t.Optional[CallbackManager] = None,
         callback_group_name: str = "batch",
     ) -> list[float]:
-        questions, answers = dataset["question"], dataset["answer"]
+        questions, answers, contexts = (
+            dataset["question"],
+            dataset["answer"],
+            dataset["contexts"],
+        )
         with trace_as_chain_group(
             callback_group_name, callback_manager=callbacks
         ) as batch_group:
             prompts = []
-            for ans in answers:
-                human_prompt = QUESTION_GEN.format(answer=ans)
+            for ans, ctx in zip(answers, contexts):
+                human_prompt = QUESTION_GEN.format(answer=ans, context="\n".join(ctx))
                 prompts.append(ChatPromptTemplate.from_messages([human_prompt]))
 
             results = self.llm.generate(
                 prompts,
                 n=self.strictness,
                 callbacks=batch_group,
            )
-            results = [[i.text for i in r] for r in results.generations]
-
+            results = [[load_as_json(i.text) for i in r] for r in results.generations]
             scores = []
-            for question, gen_questions in zip(questions, results):
-                if question is not None and question != "" and len(gen_questions) > 0:
-                    cosine_sim = self.calculate_similarity(question, gen_questions)
-                    scores.append(cosine_sim.mean())
-                else:
-                    scores.append(0.0)
+            for question, result in zip(questions, results):
+                gen_questions = [item.get("question", "") for item in result]
+                committal = np.any([item.get("noncommittal", False) for item in result])
+                cosine_sim = self.calculate_similarity(question, gen_questions)
+                scores.append(cosine_sim.mean() * int(not committal))
 
         return scores
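For reference, a minimal standalone sketch of the revised scoring flow, outside the ragas classes: each LLM generation is parsed as JSON with load_as_json-style parsing, the generated questions are embedded alongside the original question, and the mean cosine similarity is zeroed whenever any generation is flagged noncommittal. The embed function below is a toy stand-in for the configured RagasEmbeddings, and the sample data is illustrative only, not part of this commit.

import json
import numpy as np

def embed(texts: list[str]) -> np.ndarray:
    # Toy stand-in for a real embedding model (ragas uses its configured
    # RagasEmbeddings here): hash words into a fixed-size bag-of-words vector.
    vecs = np.zeros((len(texts), 64))
    for row, text in enumerate(texts):
        for word in text.lower().split():
            vecs[row, hash(word) % 64] += 1.0
    return vecs

def answer_relevancy_score(question: str, generations: list[str]) -> float:
    # Each generation is expected to be a JSON object such as
    # {"question": "...", "noncommittal": false}, as produced by QUESTION_GEN.
    parsed = [json.loads(g) for g in generations]
    gen_questions = [item.get("question", "") for item in parsed]
    noncommittal = any(item.get("noncommittal", False) for item in parsed)

    q_vec = embed([question])[0]
    g_vecs = embed(gen_questions)
    # Cosine similarity between the original question and each generated question.
    sims = (g_vecs @ q_vec) / (
        np.linalg.norm(g_vecs, axis=1) * np.linalg.norm(q_vec) + 1e-10
    )
    # A noncommittal answer ("I don't know ...") forces the score to 0.
    return float(sims.mean() * int(not noncommittal))

generations = [
    '{"question": "Where was Albert Einstein born?", "noncommittal": false}',
    '{"question": "In which country was Einstein born?", "noncommittal": false}',
]
print(answer_relevancy_score("Where was Albert Einstein born?", generations))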

src/ragas/testset/testset_generator.py

Lines changed: 21 additions & 10 deletions
@@ -58,7 +58,16 @@
     "conditional": "_condition_question",
 }
 
-DataRow = namedtuple("DataRow", ["question", "ground_truth_context", "ground_truth", "question_type", "episode_done"])
+DataRow = namedtuple(
+    "DataRow",
+    [
+        "question",
+        "ground_truth_context",
+        "ground_truth",
+        "question_type",
+        "episode_done",
+    ],
+)
 
 
 @dataclass
@@ -73,11 +82,11 @@ def to_pandas(self) -> pd.DataFrame:
         data_samples = []
         for data in self.test_data:
             data = {
-                    "question": data.question,
-                    "ground_truth_context": data.ground_truth_context,
-                    "ground_truth": data.ground_truth,
-                    "question_type": data.question_type,
-                    "episode_done": data.episode_done,
+                "question": data.question,
+                "ground_truth_context": data.ground_truth_context,
+                "ground_truth": data.ground_truth,
+                "question_type": data.question_type,
+                "episode_done": data.episode_done,
             }
             data_samples.append(data)
 
@@ -394,11 +403,13 @@ def generate(
             context = self._generate_context(question, text_chunk)
             is_conv = len(context) > 1
             answer = self._generate_answer(question, context)
-            for i, (qstn, ctx, ans) in enumerate(zip(question.split("\n"), context, answer)):
-                episode_done = False if is_conv and i==0 else True
+            for i, (qstn, ctx, ans) in enumerate(
+                zip(question.split("\n"), context, answer)
+            ):
+                episode_done = False if is_conv and i == 0 else True
                 samples.append(
-                        DataRow(qstn, [ctx], [ans], evolve_type, episode_done)
-                        )
+                    DataRow(qstn, [ctx], [ans], evolve_type, episode_done)
+                )
                 count += 1
             pbar.update(count)
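As an illustration of the episode_done flag used when building DataRow samples: only the first turn of a multi-part (conversational) question is marked not done; every other row gets True. The question, context, and answer values below are hypothetical, and "conversational" is only a placeholder for evolve_type.

from collections import namedtuple

DataRow = namedtuple(
    "DataRow",
    ["question", "ground_truth_context", "ground_truth", "question_type", "episode_done"],
)

# Hypothetical two-part question, split on "\n" as in generate():
question = "Who founded ISRO?\nWhen was it founded?"
contexts = ["ISRO was founded by Vikram Sarabhai.", "It was set up in 1969."]
answers = ["Vikram Sarabhai", "1969"]
is_conv = len(contexts) > 1

rows = []
for i, (qstn, ctx, ans) in enumerate(zip(question.split("\n"), contexts, answers)):
    # Only the first turn of a conversational sample is marked as not done.
    episode_done = False if is_conv and i == 0 else True
    rows.append(DataRow(qstn, [ctx], [ans], "conversational", episode_done))

# rows[0].episode_done == False, rows[1].episode_done == True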

src/ragas/validation.py

Lines changed: 2 additions & 4 deletions
@@ -9,11 +9,9 @@ def remap_column_names(dataset: Dataset, column_map: dict[str, str]) -> Dataset:
     """
     Remap the column names in case dataset uses different column names
     """
-
+
     inverse_column_map = {v: k for k, v in column_map.items()}
-    return dataset.rename_columns(
-        inverse_column_map
-    )
+    return dataset.rename_columns(inverse_column_map)
 
 
 def validate_column_dtypes(ds: Dataset):
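A short usage sketch of the simplified remap_column_names: the user-supplied column_map maps the column names ragas expects to the names actually used in the dataset, so it is inverted before being passed to Dataset.rename_columns. The dataset contents and column names below are illustrative only.

from datasets import Dataset

from ragas.validation import remap_column_names

ds = Dataset.from_dict(
    {
        "query": ["What is the capital of France?"],
        "response": ["Paris"],
    }
)

# Map ragas' expected column names to the names used in this dataset.
column_map = {"question": "query", "answer": "response"}

remapped = remap_column_names(ds, column_map)
print(remapped.column_names)  # ['question', 'answer']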

tests/unit/test_validation.py

Lines changed: 1 addition & 3 deletions
@@ -103,9 +103,7 @@ def test_column_remap(column_map):
         }
     )
     remapped_dataset = remap_column_names(TEST_DATASET, column_map)
-    assert all(
-        col in remapped_dataset.column_names for col in column_map.keys()
-    )
+    assert all(col in remapped_dataset.column_names for col in column_map.keys())
 
 
 def test_column_remap_omit():
