Skip to content

Commit 2a4a5ad

Browse files
authored
Introduce epsilon in denominators to avoid division by zero (#1622)
This only affects `src/ragas/metrics/_topic_adherence.py`. Previously, it could be complicated to handle the score given by this metric due to the possibility of division by zero. I propose introducing an epsilon in the denominators to avoid that case.
1 parent c3a1831 commit 2a4a5ad

File tree

1 file changed

+18
-20
lines changed

1 file changed

+18
-20
lines changed

src/ragas/metrics/_topic_adherence.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ class TopicClassificationOutput(BaseModel):
4848
class TopicClassificationPrompt(
4949
PydanticPrompt[TopicClassificationInput, TopicClassificationOutput]
5050
):
51-
instruction = "Given a set of topics classify if the topic falls into any of the given reference topics."
51+
instruction = (
52+
"Given a set of topics classify if the topic falls into any of the given reference topics."
53+
)
5254
input_model = TopicClassificationInput
5355
output_model = TopicClassificationOutput
5456
examples = [
@@ -66,7 +68,9 @@ class TopicClassificationPrompt(
6668

6769

6870
class TopicRefusedPrompt(PydanticPrompt[TopicRefusedInput, TopicRefusedOutput]):
69-
instruction: str = "Given a topic, classify if the AI refused to answer the question about the topic."
71+
instruction: str = (
72+
"Given a topic, classify if the AI refused to answer the question about the topic."
73+
)
7074
input_model = TopicRefusedInput
7175
output_model = TopicRefusedOutput
7276
examples = [
@@ -93,10 +97,10 @@ class TopicRefusedPrompt(PydanticPrompt[TopicRefusedInput, TopicRefusedOutput]):
9397
]
9498

9599

96-
class TopicExtractionPrompt(
97-
PydanticPrompt[TopicExtractionInput, TopicExtractionOutput]
98-
):
99-
instruction: str = "Given an interaction between Human, Tool and AI, extract the topics from Human's input."
100+
class TopicExtractionPrompt(PydanticPrompt[TopicExtractionInput, TopicExtractionOutput]):
101+
instruction: str = (
102+
"Given an interaction between Human, Tool and AI, extract the topics from Human's input."
103+
)
100104
input_model = TopicExtractionInput
101105
output_model = TopicExtractionOutput
102106
examples = [
@@ -143,14 +147,10 @@ class TopicAdherenceScore(MetricWithLLM, MultiTurnMetric):
143147
topic_classification_prompt: PydanticPrompt = TopicClassificationPrompt()
144148
topic_refused_prompt: PydanticPrompt = TopicRefusedPrompt()
145149

146-
async def _multi_turn_ascore(
147-
self, sample: MultiTurnSample, callbacks: Callbacks
148-
) -> float:
150+
async def _multi_turn_ascore(self, sample: MultiTurnSample, callbacks: Callbacks) -> float:
149151
assert self.llm is not None, "LLM must be set"
150152
assert isinstance(sample.user_input, list), "Sample user_input must be a list"
151-
assert isinstance(
152-
sample.reference_topics, list
153-
), "Sample reference_topics must be a list"
153+
assert isinstance(sample.reference_topics, list), "Sample reference_topics must be a list"
154154
user_input = sample.pretty_repr()
155155

156156
prompt_input = TopicExtractionInput(user_input=user_input)
@@ -166,9 +166,7 @@ async def _multi_turn_ascore(
166166
data=prompt_input, llm=self.llm, callbacks=callbacks
167167
)
168168
topic_answered_verdict.append(response.refused_to_answer)
169-
topic_answered_verdict = np.array(
170-
[not answer for answer in topic_answered_verdict]
171-
)
169+
topic_answered_verdict = np.array([not answer for answer in topic_answered_verdict])
172170

173171
prompt_input = TopicClassificationInput(
174172
reference_topics=sample.reference_topics, topics=topics
@@ -183,13 +181,13 @@ async def _multi_turn_ascore(
183181
false_negatives = sum(~topic_answered_verdict & topic_classifications)
184182

185183
if self.mode == "precision":
186-
return true_positives / (true_positives + false_positives)
184+
return true_positives / (true_positives + false_positives + 1e-10)
187185
elif self.mode == "recall":
188-
return true_positives / (true_positives + false_negatives)
186+
return true_positives / (true_positives + false_negatives + 1e-10)
189187
else:
190-
precision = true_positives / (true_positives + false_positives)
191-
recall = true_positives / (true_positives + false_negatives)
192-
return 2 * (precision * recall) / (precision + recall)
188+
precision = true_positives / (true_positives + false_positives + 1e-10)
189+
recall = true_positives / (true_positives + false_negatives + 1e-10)
190+
return 2 * (precision * recall) / (precision + recall + 1e-10)
193191

194192
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
195193
return await self._multi_turn_ascore(MultiTurnSample(**row), callbacks)

0 commit comments

Comments
 (0)