Skip to content

Commit be18ae5

Browse files
authored
feat: add Asian language support to CorpusLevelTranslationMetric (#479)
* feat: add asian language support to CorpusLevelTranslationMetric * fix: ci
1 parent fee2ec3 commit be18ae5

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

src/lighteval/metrics/metrics_corpus.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"""
2727
import logging
2828
import math
29+
from typing import Literal
2930

3031
import numpy as np
3132
import sacrebleu
@@ -89,33 +90,38 @@ def compute(self, items: list[LogprobCorpusMetricInput]):
8990

9091

9192
class CorpusLevelTranslationMetric:
    def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
        """Stores the relevant parameters for a corpus level translation metric.

        Args:
            metric_type (str): Can be any of bleu, chrf, or ter depending on the metric to use.
            lang (str): Target language code ("zh", "ja" or "ko") to enable
                Asian-language tokenization; empty string keeps sacrebleu's defaults.
        """
        self.metric_type = metric_type
        self.lang = lang

    def get_metric(self):
        """Instantiates and returns the sacrebleu metric object matching `metric_type`.

        A fresh instance is built on each call so the language-dependent tokenizer
        settings are applied.

        Raises:
            ValueError: If `metric_type` is not one of bleu, chrf, or ter.
        """
        if self.metric_type == "bleu":
            # trg_lang selects the appropriate tokenizer (e.g. "zh" -> Chinese tokenization).
            return sacrebleu.BLEU(trg_lang=self.lang)
        elif self.metric_type == "chrf":
            return sacrebleu.CHRF()
        elif self.metric_type == "ter":
            # TER only exposes a boolean switch for Asian-language handling.
            return sacrebleu.TER(asian_support=self.lang != "")
        else:
            raise ValueError(f"Unknown corpus level translation metric type : {self.metric_type}")

    def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
        """Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
        metric = self.get_metric()
        golds = [i.golds for i in items]
        preds = []
        for i in items:
            pred = as_list(i.preds)
            if len(pred) > 1:
                # BLEU/CHRF/TER instances have no __name__ attribute (unlike the
                # corpus_* functions used previously), so report the class name
                # instead — `metric.__name__` would raise AttributeError here.
                logger.info(
                    f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{type(metric).__name__})."
                )
            preds.append(pred[0])
        return float(metric.corpus_score(hypotheses=preds, references=golds).score)
119125

120126

121127
class CorpusLevelPerplexityMetric:

0 commit comments

Comments
 (0)