|
26 | 26 | """ |
27 | 27 | import logging |
28 | 28 | import math |
| 29 | +from typing import Literal |
29 | 30 |
|
30 | 31 | import numpy as np |
31 | 32 | import sacrebleu |
@@ -89,33 +90,38 @@ def compute(self, items: list[LogprobCorpusMetricInput]): |
89 | 90 |
|
90 | 91 |
|
class CorpusLevelTranslationMetric:
    def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
        """Stores the relevant parameters for a corpus level translation metric.

        Args:
            metric_type (str): Can be any of bleu, chrf, or ter depending on the metric to use.
            lang (str): Target language code used to pick a tokenizer ("zh", "ja", "ko"),
                or "" for the default language-agnostic tokenization.
        """
        self.metric_type = metric_type
        self.lang = lang

    def get_metric(self):
        """Builds and returns the sacrebleu metric object matching `metric_type`.

        A fresh instance is created per call so successive `compute` invocations
        never share tokenizer state.

        Raises:
            ValueError: If `metric_type` is not one of "bleu", "chrf", or "ter".
        """
        if self.metric_type == "bleu":
            # trg_lang selects a target-language-aware tokenizer (e.g. "zh" for Chinese).
            return sacrebleu.BLEU(trg_lang=self.lang)
        elif self.metric_type == "chrf":
            return sacrebleu.CHRF()
        elif self.metric_type == "ter":
            # TER requires asian_support to tokenize CJK targets correctly.
            return sacrebleu.TER(asian_support=self.lang != "")
        else:
            raise ValueError(f"Unknown corpus level translation metric type : {self.metric_type}")

    def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
        """Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
        metric = self.get_metric()
        golds = [i.golds for i in items]
        preds = []
        for i in items:
            pred = as_list(i.preds)
            if len(pred) > 1:
                logger.info(
                    # `metric` is an instance (BLEU/CHRF/TER), not a function: instances
                    # have no `__name__`, so use the class name to avoid AttributeError.
                    f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{type(metric).__name__})."
                )
            preds.append(pred[0])
        return float(metric.corpus_score(hypotheses=preds, references=golds).score)
119 | 125 |
|
120 | 126 |
|
121 | 127 | class CorpusLevelPerplexityMetric: |
|
0 commit comments