
Commit 8ddf9fe

Initialise all metrics (#21)

* init all metrics
* change metric imports

1 parent bc9d645 commit 8ddf9fe

File tree

5 files changed, +43 -44 lines changed

ragas/metrics/__init__.py

Lines changed: 19 additions & 11 deletions
@@ -1,17 +1,25 @@
 from ragas.metrics.base import Evaluation, Metric
-from ragas.metrics.factual import EntailmentScore
-from ragas.metrics.similarity import SBERTScore
-from ragas.metrics.simple import BLUE, EditDistance, EditRatio, Rouge1, Rouge2, RougeL
+from ragas.metrics.factual import entailment_score, q_square
+from ragas.metrics.similarity import bert_score
+from ragas.metrics.simple import (
+    bleu_score,
+    edit_distance,
+    edit_ratio,
+    rouge1,
+    rouge2,
+    rougeL,
+)
 
 __all__ = [
     "Evaluation",
     "Metric",
-    "EntailmentScore",
-    "SBERTScore",
-    "BLUE",
-    "EditDistance",
-    "EditRatio",
-    "RougeL",
-    "Rouge1",
-    "Rouge2",
+    "entailment_score",
+    "bert_score",
+    "q_square",
+    "bleu_score",
+    "edit_distance",
+    "edit_ratio",
+    "rouge1",
+    "rouge2",
+    "rougeL",
 ]
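
For reference, a minimal sketch of how the renamed, pre-initialised exports might be used after this change. The score(ground_truth, generated_text) signature is taken from the hunks further down; what score() returns is an assumption and is only printed here.

    # Sketch only: assumes .score() returns something printable per sample.
    from ragas.metrics import edit_ratio, rouge1

    ground_truth = ["the cat sat on the mat"]
    generated_text = ["a cat was sitting on the mat"]

    # Signature per the simple.py hunk below: score(ground_truth, generated_text)
    print(rouge1.score(ground_truth, generated_text))
    print(edit_ratio.score(ground_truth, generated_text))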

ragas/metrics/factual.py

Lines changed: 3 additions & 3 deletions
@@ -218,7 +218,7 @@ def __post_init__(
 
     @property
     def name(self):
-        return "Q^2"
+        return "Qsquare"
 
     @property
     def is_batchable(self):
@@ -340,5 +340,5 @@ def score(self, ground_truth: list[str], generated_text: list[str], **kwargs):
         return scores
 
 
-ENTScore = EntailmentScore()
-Q2Score = Qsquare()
+entailment_score = EntailmentScore()
+q_square = Qsquare()
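
The new module-level instances are built with default arguments. The benchmark code removed further down passed max_length and device to EntailmentScore, so a customised instance would presumably still be constructed directly; a hedged sketch, assuming those constructor parameters are unchanged by this commit:

    # Assumption: EntailmentScore still accepts max_length and device,
    # as in the benchmark code removed elsewhere in this commit.
    from ragas.metrics.factual import EntailmentScore

    entailment_score_gpu = EntailmentScore(max_length=512, device="cuda")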

ragas/metrics/similarity.py

Lines changed: 5 additions & 5 deletions
@@ -12,12 +12,12 @@
 if t.TYPE_CHECKING:
     from torch import Tensor
 
-SBERT_METRIC = t.Literal["cosine", "euclidean"]
+BERT_METRIC = t.Literal["cosine", "euclidean"]
 
 
 @dataclass
-class SBERTScore(Metric):
-    similarity_metric: t.Literal[SBERT_METRIC] = "cosine"
+class BERTScore(Metric):
+    similarity_metric: t.Literal[BERT_METRIC] = "cosine"
     model_path: str = "all-MiniLM-L6-v2"
     batch_size: int = 1000
 
@@ -28,7 +28,7 @@ def __post_init__(self):
     def name(
         self,
     ):
-        return f"SBERT_{self.similarity_metric}"
+        return f"BERTScore_{self.similarity_metric}"
 
     @property
     def is_batchable(self):
@@ -64,4 +64,4 @@ def score(
         return score
 
 
-__all__ = ["SBERTScore"]
+bert_score = BERTScore()
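
BERTScore is a dataclass, so a non-default instance can still be created alongside the exported bert_score singleton. A sketch using only the fields visible in the hunk above; the euclidean variant is allowed by BERT_METRIC but is not exercised anywhere in this diff:

    # Field names and defaults come from the dataclass definition above;
    # the specific values here are illustrative only.
    from ragas.metrics.similarity import BERTScore

    bert_score_euclidean = BERTScore(
        similarity_metric="euclidean",
        model_path="all-MiniLM-L6-v2",
        batch_size=500,
    )
    print(bert_score_euclidean.name)  # "BERTScore_euclidean" per the name property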

ragas/metrics/simple.py

Lines changed: 6 additions & 6 deletions
@@ -91,9 +91,9 @@ def score(self, ground_truth: t.List[str], generated_text: t.List[str]):
         return score
 
 
-Rouge1 = ROUGE("rouge1")
-Rouge2 = ROUGE("rouge2")
-RougeL = ROUGE("rougeL")
-BLUE = BLEUScore()
-EditDistance = EditScore("distance")
-EditRatio = EditScore("ratio")
+rouge1 = ROUGE("rouge1")
+rouge2 = ROUGE("rouge2")
+rougeL = ROUGE("rougeL")
+bleu_score = BLEUScore()
+edit_distance = EditScore("distance")
+edit_ratio = EditScore("ratio")
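
edit_distance and edit_ratio wrap the same EditScore implementation with different modes, and both follow the score(ground_truth, generated_text) signature shown in the hunk header above. A small sketch contrasting the two; the scale of the returned values is an assumption, not something this diff states:

    from ragas.metrics import edit_distance, edit_ratio

    refs = ["hello world"]
    preds = ["hallo world"]

    # Presumably an absolute edit count vs. a normalised ratio; the diff
    # only shows the mode strings "distance" and "ratio".
    print(edit_distance.score(refs, preds))
    print(edit_ratio.score(refs, preds))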

tests/benchmarks/benchmark.py

Lines changed: 10 additions & 19 deletions
@@ -1,36 +1,27 @@
 import typing as t
 
-from datasets import Dataset, load_dataset
+from datasets import Dataset, arrow_dataset, load_dataset
 from torch.cuda import is_available
 from tqdm import tqdm
 from utils import print_table, timeit
 
-from ragas.metrics import (
-    EditDistance,
-    EditRatio,
-    EntailmentScore,
-    Evaluation,
-    Rouge1,
-    Rouge2,
-    RougeL,
-    SBERTScore,
-)
+from ragas.metrics import Evaluation, edit_distance, edit_ratio, rouge1, rouge2, rougeL
 
 DEVICE = "cuda" if is_available() else "cpu"
 BATCHES = [0, 1]
-# init metrics
-sbert_score = SBERTScore(similarity_metric="cosine")
-entail = EntailmentScore(max_length=512, device=DEVICE)
+
 METRICS = {
-    "Rouge1": Rouge1,
-    "Rouge2": Rouge2,
-    "RougeL": RougeL,
-    "EditRatio": EditRatio,
-    "EditDistance": EditDistance,
+    "Rouge1": rouge1,
+    "Rouge2": rouge2,
+    "RougeL": rougeL,
+    "EditRatio": edit_ratio,
+    "EditDistance": edit_distance,
     # "SBERTScore": sbert_score,
     # "EntailmentScore": entail,
 }
 DS = load_dataset("explodinggradients/eli5-test", split="test_eli5")
+assert isinstance(DS, arrow_dataset.Dataset), "Not an arrow_dataset"
+DS = DS.select(range(100))
 
 
 def setup() -> t.Iterator[tuple[str, Evaluation, Dataset]]:
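
With the pre-initialised metrics, the benchmark's METRICS dict maps display names straight to the exported singletons and runs over the first 100 rows of the ELI5 test split. A sketch of that loop under assumed column names; "ground_truth" and "generated_text" are guesses, since the eli5-test schema is not shown in this diff:

    # Hypothetical standalone version of the benchmark loop; the real file
    # uses setup(), timeit and print_table from utils, which are not shown here.
    from datasets import load_dataset

    from ragas.metrics import edit_distance, edit_ratio, rouge1, rouge2, rougeL

    metrics = {
        "Rouge1": rouge1,
        "Rouge2": rouge2,
        "RougeL": rougeL,
        "EditRatio": edit_ratio,
        "EditDistance": edit_distance,
    }
    ds = load_dataset("explodinggradients/eli5-test", split="test_eli5")
    ds = ds.select(range(100))

    for name, metric in metrics.items():
        # Column names here are assumptions, not taken from the diff.
        print(name, metric.score(ds["ground_truth"], ds["generated_text"]))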

0 commit comments
