
Commit 6d114e5

This commit implements the F-beta score metric (#1543) for the AnswerCorrectness class. The beta parameter is introduced to control the relative importance of recall and precision when calculating the score. Specifically:

- beta > 1 places more emphasis on recall.
- beta < 1 favors precision.
- beta == 1 gives the regular F1 score, which can be interpreted as the harmonic mean of precision and recall.

Key changes: the method _compute_statement_presence is updated to calculate the F-beta score from true positives (TP), false positives (FP), and false negatives (FN). This lets the balance between recall and precision be tuned to the task's requirements via the beta value.

Source: https://scikit-learn.org/1.5/modules/generated/sklearn.metrics.fbeta_score.html

Co-authored-by: Shahules786 <[email protected]>
1 parent fd5e805 commit 6d114e5
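For reference, the score introduced here follows the standard F-beta definition implemented by the new fbeta_score helper below (and by the scikit-learn function linked above):

    F_beta = (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall)

with precision = TP / (TP + FP) and recall = TP / (TP + FN). In the limit of large beta the expression reduces to recall, and for small beta it reduces to precision, which is why beta > 1 emphasizes recall and beta < 1 emphasizes precision.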

File tree: 5 files changed, +63 -53 lines changed


src/ragas/metrics/__init__.py

Lines changed: 0 additions & 10 deletions
@@ -1,6 +1,3 @@
-import inspect
-import sys
-
 from ragas.metrics._answer_correctness import AnswerCorrectness, answer_correctness
 from ragas.metrics._answer_relevance import (
     AnswerRelevancy,
@@ -120,10 +117,3 @@
     "MultiModalRelevance",
     "multimodal_relevance",
 ]
-
-current_module = sys.modules[__name__]
-ALL_METRICS = [
-    obj
-    for name, obj in inspect.getmembers(current_module)
-    if name in __all__ and not inspect.isclass(obj) and not inspect.isbuiltin(obj)
-]

src/ragas/metrics/_answer_correctness.py

Lines changed: 8 additions & 1 deletion
@@ -21,6 +21,7 @@
     SingleTurnMetric,
     get_segmenter,
 )
+from ragas.metrics.utils import fbeta_score
 from ragas.prompt import PydanticPrompt
 from ragas.run_config import RunConfig
 
@@ -167,6 +168,7 @@ class AnswerCorrectness(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
         default_factory=LongFormAnswerPrompt
     )
     weights: list[float] = field(default_factory=lambda: [0.75, 0.25])
+    beta: float = 1.0
     answer_similarity: t.Optional[AnswerSimilarity] = None
     sentence_segmenter: t.Optional[HasSegmentMethod] = None
     max_retries: int = 1
@@ -185,6 +187,11 @@ def __post_init__(self: t.Self):
         language = self.long_form_answer_prompt.language
         self.sentence_segmenter = get_segmenter(language=language, clean=False)
 
+        if type(self.beta) is not float:
+            raise ValueError(
+                "Beta must be a float. A beta > 1 gives more weight to recall, while beta < 1 favors precision."
+            )
+
     def init(self, run_config: RunConfig):
         super().init(run_config)
         if self.answer_similarity is None and self.weights[1] != 0:
@@ -198,7 +205,7 @@ def _compute_statement_presence(
         tp = len(prediction.TP)
         fp = len(prediction.FP)
         fn = len(prediction.FN)
-        score = tp / (tp + 0.5 * (fp + fn)) if tp > 0 else 0
+        score = fbeta_score(tp, fp, fn, self.beta)
         return score
 
     async def _create_simplified_statements(

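As a quick illustration of the new knob, here is a hypothetical usage sketch (not part of this commit); configuring the evaluator LLM and embeddings follows the usual ragas setup and is omitted:

from ragas.metrics import AnswerCorrectness

# The __post_init__ check added above requires beta to be a float,
# so pass 2.0 / 0.5 rather than bare ints.
recall_focused = AnswerCorrectness(beta=2.0)     # beta > 1: weight recall
precision_focused = AnswerCorrectness(beta=0.5)  # beta < 1: weight precision
f1_default = AnswerCorrectness()                 # beta = 1.0: plain F1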
src/ragas/metrics/_factual_correctness.py

Lines changed: 35 additions & 9 deletions
@@ -16,6 +16,7 @@
     SingleTurnMetric,
     get_segmenter,
 )
+from ragas.metrics.utils import fbeta_score
 from ragas.prompt import PydanticPrompt
 
 if t.TYPE_CHECKING:
@@ -181,11 +182,32 @@ class ClaimDecompositionPrompt(
 
 @dataclass
 class FactualCorrectness(MetricWithLLM, SingleTurnMetric):
+    """
+    FactualCorrectness is a metric class that evaluates the factual correctness of responses
+    generated by a language model. It uses claim decomposition and natural language inference (NLI)
+    to verify the claims made in the responses against reference texts.
+
+    Attributes:
+        name (str): The name of the metric, default is "factual_correctness".
+        _required_columns (Dict[MetricType, Set[str]]): A dictionary specifying the required columns
+            for each metric type. Default is {"SINGLE_TURN": {"response", "reference"}}.
+        mode (Literal["precision", "recall", "f1"]): The mode of evaluation, can be "precision",
+            "recall", or "f1". Default is "f1".
+        beta (float): The beta value used for the F1 score calculation. A beta > 1 gives more weight
+            to recall, while beta < 1 favors precision. Default is 1.0.
+        atomicity (Literal["low", "high"]): The level of atomicity for claim decomposition. Default is "low".
+        coverage (Literal["low", "high"]): The level of coverage for claim decomposition. Default is "low".
+        claim_decomposition_prompt (PydanticPrompt): The prompt used for claim decomposition.
+        nli_prompt (PydanticPrompt): The prompt used for natural language inference (NLI).
+
+    """
+
     name: str = "factual_correctness"  # type: ignore
     _required_columns: t.Dict[MetricType, t.Set[str]] = field(
         default_factory=lambda: {MetricType.SINGLE_TURN: {"response", "reference"}}
     )
     mode: t.Literal["precision", "recall", "f1"] = "f1"
+    beta: float = 1.0
     atomicity: t.Literal["low", "high"] = "low"
     coverage: t.Literal["low", "high"] = "low"
     claim_decomposition_prompt: PydanticPrompt = ClaimDecompositionPrompt()
@@ -204,6 +226,11 @@ def __post_init__(self):
         )
         self.segmenter = get_segmenter(language="english")
 
+        if type(self.beta) is not float:
+            raise ValueError(
+                "Beta must be a float. A beta > 1 gives more weight to recall, while beta < 1 favors precision."
+            )
+
     async def decompose_claims(
         self, response: str, callbacks: Callbacks
     ) -> t.List[str]:
@@ -253,21 +280,20 @@ async def _single_turn_ascore(
         else:
             response_reference = np.array([])
 
-        true_positives = sum(reference_response)
-        false_positives = sum(~reference_response)
+        tp = sum(reference_response)
+        fp = sum(~reference_response)
         if self.mode != "precision":
-            false_negatives = sum(~response_reference)
+            fn = sum(~response_reference)
         else:
-            false_negatives = 0
+            fn = 0
+
 
         if self.mode == "precision":
-            score = true_positives / (true_positives + false_positives + 1e-8)
+            score = tp / (tp + fp + 1e-8)
         elif self.mode == "recall":
-            score = true_positives / (true_positives + false_negatives + 1e-8)
+            score = tp / (tp + fp + 1e-8)
         else:
-            precision = true_positives / (true_positives + false_positives + 1e-8)
-            recall = true_positives / (true_positives + false_negatives + 1e-8)
-            score = 2 * (precision * recall) / (precision + recall + 1e-8)
+            score = fbeta_score(tp, fp, fn, self.beta)
 
         return np.round(score, 2)

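Similarly, a hypothetical sketch of how beta composes with mode for this metric (not part of the diff): only the "f1" branch consults beta, while "precision" and "recall" ignore it.

from ragas.metrics._factual_correctness import FactualCorrectness

# Recall-leaning factual correctness; beta must be a float per __post_init__.
scorer = FactualCorrectness(mode="f1", beta=2.0, atomicity="low", coverage="low")

# Actually scoring a SingleTurnSample additionally needs an evaluator LLM, e.g.
#   scorer.llm = evaluator_llm                # assumed to be configured elsewhere
#   await scorer.single_turn_ascore(sample)   # SingleTurnMetric entry point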
src/ragas/metrics/utils.py

Lines changed: 19 additions & 18 deletions
@@ -1,21 +1,22 @@
-from ragas.dataset_schema import EvaluationDataset
-from ragas.metrics import ALL_METRICS
-from ragas.metrics.base import Metric
-from ragas.validation import validate_required_columns
+def fbeta_score(tp, fp, fn, beta=1.0):
+    if tp + fp == 0:
+        precision = 0
+    else:
+        precision = tp / (tp + fp)
 
+    if tp + fn == 0:
+        recall = 0
+    else:
+        recall = tp / (tp + fn)
 
-def get_available_metrics(ds: EvaluationDataset) -> list[Metric]:
-    """
-    Get the available metrics for the given dataset.
-    E.g. if the dataset contains ("question", "answer", "contexts") columns,
-    the available metrics are those that can be evaluated in [qa, qac, qc] mode.
-    """
-    available_metrics = []
-    for metric in ALL_METRICS:
-        try:
-            validate_required_columns(ds, [metric])
-            available_metrics.append(metric)
-        except ValueError:
-            pass
+    if precision == 0 and recall == 0:
+        return 0.0
 
-    return available_metrics
+    beta_squared = beta**2
+    fbeta = (
+        (1 + beta_squared)
+        * (precision * recall)
+        / ((beta_squared * precision) + recall)
+    )
+
+    return fbeta

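A quick sanity check of the new helper with made-up counts (illustrative numbers only): precision here is 5/7 ≈ 0.71 and recall is 5/8 ≈ 0.63, so lowering beta (favoring precision) raises the score and raising beta (favoring recall) lowers it.

from ragas.metrics.utils import fbeta_score

tp, fp, fn = 5, 2, 3  # hypothetical statement counts

print(fbeta_score(tp, fp, fn, beta=0.5))  # ~0.694, leans on precision
print(fbeta_score(tp, fp, fn, beta=1.0))  # ~0.667, plain F1
print(fbeta_score(tp, fp, fn, beta=2.0))  # ~0.641, leans on recall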
tests/unit/test_metric.py

Lines changed: 1 addition & 15 deletions
@@ -1,22 +1,8 @@
 import typing as t
 from dataclasses import dataclass, field
 
-from ragas.dataset_schema import EvaluationDataset, SingleTurnSample
+from ragas.dataset_schema import SingleTurnSample
 from ragas.metrics.base import MetricType
-from ragas.metrics.utils import get_available_metrics
-
-
-def test_get_available_metrics():
-    sample1 = SingleTurnSample(user_input="What is X", response="Y")
-    sample2 = SingleTurnSample(user_input="What is Z", response="W")
-    ds = EvaluationDataset(samples=[sample1, sample2])
-
-    assert all(
-        [
-            m.required_columns["SINGLE_TURN"] == {"response", "user_input"}
-            for m in get_available_metrics(ds)
-        ]
-    ), "All metrics should have required columns ('user_input', 'response')"
 
 
 def test_single_turn_metric():

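The removed test covered the deleted get_available_metrics helper; the commit itself does not add a unit test for fbeta_score. A minimal pytest sketch of what one could look like (hypothetical, not part of this commit):

import pytest

from ragas.metrics.utils import fbeta_score


@pytest.mark.parametrize(
    "tp, fp, fn, beta, expected",
    [
        (5, 2, 3, 1.0, 2 / 3),     # F1: harmonic mean of precision and recall
        (5, 2, 3, 2.0, 25 / 39),   # beta > 1 leans toward recall
        (5, 2, 3, 0.5, 6.25 / 9),  # beta < 1 leans toward precision
        (0, 4, 4, 1.0, 0.0),       # no true positives -> score is 0
    ],
)
def test_fbeta_score(tp, fp, fn, beta, expected):
    assert fbeta_score(tp, fp, fn, beta) == pytest.approx(expected)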