
Commit 86b66d8

Document WMTEvaluator

1 parent b1baed3 commit 86b66d8
File tree

  • sotabencheval/machine_translation/wmt.py

1 file changed: 122 additions, 1 deletion
@@ -15,7 +15,41 @@ class WMTDataset(Enum):


class WMTEvaluator(BaseEvaluator):
    """Evaluator for WMT Machine Translation benchmarks.

    Examples:
        Evaluate a Transformer model from the fairseq repository on the WMT2019 news test set:

        .. code-block:: python

            from sotabencheval.machine_translation import WMTEvaluator, WMTDataset, Language
            from tqdm import tqdm
            import torch

            evaluator = WMTEvaluator(
                dataset=WMTDataset.News2019,
                source_lang=Language.English,
                target_lang=Language.German,
                local_root="data/nlp/wmt",
                model_name="Facebook-FAIR (single)",
                paper_arxiv_id="1907.06616"
            )

            model = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model',
                                   force_reload=True, tokenizer='moses', bpe='fastbpe').cuda()

            for sid, text in tqdm(evaluator.metrics.source_segments.items()):
                translated = model.translate(text)
                evaluator.add({sid: translated})
                if evaluator.cache_exists:
                    break

            evaluator.save()
            print(evaluator.results)
    """

    task = "Machine Translation"

    _datasets = {
        (WMTDataset.News2014, Language.English, Language.German),
        (WMTDataset.News2019, Language.English, Language.German),
@@ -35,6 +69,53 @@ def __init__(self,
                 paper_results: dict = None,
                 model_description: str = None,
                 tokenization: Callable[[str], str] = None):
        """
        Creates an evaluator for one of the WMT benchmarks.

        :param dataset: Which dataset to evaluate on, e.g., WMTDataset.News2014.
        :param source_lang: Source language of the documents to translate.
        :param target_lang: Target language into which the documents are translated.
        :param local_root: Path to the directory where the dataset files are located locally.
            Ignored when run on the sotabench server.
        :param source_dataset_filename: Local filename of the SGML file with the source documents.
            If None, the standard WMT filename is used, based on :param:`dataset`,
            :param:`source_lang` and :param:`target_lang`.
            Ignored when run on the sotabench server.
        :param target_dataset_filename: Local filename of the SGML file with the reference documents.
            If None, the standard WMT filename is used, based on :param:`dataset`,
            :param:`source_lang` and :param:`target_lang`.
            Ignored when run on the sotabench server.
        :param model_name: The name of the model from the paper - if you want to link your build
            to a model from a machine learning paper. See the WMT benchmarks pages for model names
            (e.g., https://sotabench.com/benchmarks/machine-translation-on-wmt2014-english-german),
            either on the paper leaderboard or on the "models yet to try" tab.
        :param paper_arxiv_id: Optional linking to arXiv if you want to link to papers on the
            leaderboard; put in the corresponding paper's arXiv ID, e.g. '1907.06616'.
        :param paper_pwc_id: Optional linking to Papers With Code; put in the corresponding
            Papers With Code URL slug, e.g. 'facebook-fairs-wmt19-news-translation-task'.
        :param paper_results: If the paper model you are reproducing does not have model results
            on sotabench.com, you can specify the paper results yourself through this argument,
            where keys are metric names and values are metric values, e.g.:

                {'SacreBLEU': 42.7, 'BLEU score': 43.1}

            Ensure that the metric names match those on the sotabench leaderboard - for WMT
            benchmarks it should be `SacreBLEU` for the de-tokenized mixed-case BLEU score and
            `BLEU score` for the tokenized BLEU score.
        :param model_description: Optional model description.
        :param tokenization: An optional tokenization function used to compute the tokenized BLEU
            score. It takes a single string - a segment to tokenize - and returns a string with
            tokens separated by spaces, e.g.:

                tokenization = lambda seg: seg.replace("'s", " 's").replace("-", " - ")

            If None, only the de-tokenized SacreBLEU score is reported.
        """

        super().__init__(model_name, paper_arxiv_id, paper_pwc_id, paper_results, model_description)
        self.root = change_root_if_server(root=local_root,
                                          server_root=".data/nlp/wmt")
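
The `tokenization` hook documented above accepts any callable with a str-to-str signature. A minimal sketch of passing one to the constructor; the `split_compounds` helper is hypothetical and only illustrates the expected shape:

.. code-block:: python

    from sotabencheval.machine_translation import WMTEvaluator, WMTDataset, Language

    # Hypothetical tokenizer: splits possessives and hyphens so that the
    # tokenized `BLEU score` metric is reported in addition to SacreBLEU.
    def split_compounds(seg: str) -> str:
        return seg.replace("'s", " 's").replace("-", " - ")

    evaluator = WMTEvaluator(
        dataset=WMTDataset.News2014,
        source_lang=Language.English,
        target_lang=Language.German,
        local_root="data/nlp/wmt",     # ignored on the sotabench server
        tokenization=split_compounds,  # omit to report only SacreBLEU
    )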
@@ -78,8 +159,27 @@ def _get_dataset_name(self):
        ds_names = {WMTDataset.News2014: "WMT2014", WMTDataset.News2019: "WMT2019"}
        return "{0} {1}-{2}".format(ds_names.get(self.dataset), self.source_lang.fullname, self.target_lang.fullname)

    def add(self, answers: Dict[str, str]):
        """
        Updates the evaluator with new results.

        :param answers: a dict where keys are source segment ids and values are translated segments.
            (A segment id is created by concatenating the document id and the original segment id,
            separated by `#`.)

        Examples:
            Update the evaluator with three results:

            .. code-block:: python

                my_evaluator.add({
                    'bbc.381790#1': 'Waliser AMs sorgen sich um "Aussehen wie Muppets"',
                    'bbc.381790#2': 'Unter einigen AMs herrscht Bestürzung über einen...',
                    'bbc.381790#3': 'Sie ist aufgrund von Plänen entstanden, den Namen...'
                })

        .. seealso:: `sotabencheval.machine_translation.TranslationMetrics.source_segments`
        """

        self.metrics.add(answers)
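
Since `add` accepts a dict, translations can also be submitted in batches rather than segment by segment. A minimal sketch, assuming a hypothetical `translate_batch` stand-in for the model's batched inference:

.. code-block:: python

    from itertools import islice

    def batches(items, size):
        # Yield successive lists of (segment id, text) pairs, at most `size` each.
        it = iter(items)
        while batch := list(islice(it, size)):
            yield batch

    for batch in batches(evaluator.metrics.source_segments.items(), 32):
        sids, texts = zip(*batch)
        outputs = translate_batch(texts)  # hypothetical batched inference call
        evaluator.add(dict(zip(sids, outputs)))
        if evaluator.cache_exists:
            break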

@@ -91,9 +191,30 @@ def add(self, answers: Dict[str, str]):
            self.first_batch_processed = True

    def reset(self):
        """
        Removes already added translations.

        When checking whether the model should be rerun on the whole dataset, it is first run on
        a smaller subset and the results are compared with values cached on the sotabench server
        (the check is not performed when running locally). Ideally, the smaller subset is just the
        first batch, so no additional computation is needed. However, for more complex multistage
        pipelines it may be simpler to run a model twice - on a small dataset and (if necessary)
        on the full dataset. In that case :func:`reset` needs to be called before the second run
        so values from the first run are not reported.

        .. seealso:: :func:`cache_exists`
        .. seealso:: :func:`reset_time`
        """
        self.metrics.reset()

    def get_results(self):
        """
        Gets the results for the evaluator. An empty string is assumed for segments for which no
        translation was provided.

        :return: dict with `SacreBLEU` and `BLEU score`
        """
        if self.cached_results:
            return self.results
        self.results = self.metrics.get_results()
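
The two-run pattern described in the `reset` docstring might look like the sketch below, where `run_pipeline` is a hypothetical multistage pipeline mapping segment ids to translations; the `cache_exists`, `reset`, and `save` calls follow the documented API:

.. code-block:: python

    from itertools import islice

    # First run: translate only a small subset so the server-side cache
    # check has results to compare against.
    subset = dict(islice(evaluator.metrics.source_segments.items(), 100))
    evaluator.add(run_pipeline(subset))  # run_pipeline is hypothetical

    if not evaluator.cache_exists:
        # Cache miss: drop the partial results and rerun on the full dataset.
        evaluator.reset()
        evaluator.add(run_pipeline(evaluator.metrics.source_segments))

    evaluator.save()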
