@@ -15,7 +15,41 @@ class WMTDataset(Enum):
 
 
 class WMTEvaluator(BaseEvaluator):
18+ """Evaluator for WMT Machine Translation benchmarks.
19+
20+ Examples:
21+ Evaluate a Transformer model from the fairseq repository on WMT2019 news test set:
22+
23+ .. code-block:: python
24+
25+ from sotabencheval.machine_translation import WMTEvaluator, WMTDataset, Language
26+ from tqdm import tqdm
27+ import torch
28+
29+ evaluator = WMTEvaluator(
30+ dataset=WMTDataset.News2019,
31+ source_lang=Language.English,
32+ target_lang=Language.German,
33+ local_root="data/nlp/wmt",
34+ model_name="Facebook-FAIR (single)",
35+ paper_arxiv_id="1907.06616"
36+ )
37+
38+ model = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model',
39+ force_reload=True, tokenizer='moses', bpe='fastbpe').cuda()
40+
41+ for sid, text in tqdm(evaluator.metrics.source_segments.items()):
42+ translated = model.translate(text)
43+ evaluator.add({sid: translated})
44+ if evaluator.cache_exists:
45+ break
46+
47+ evaluator.save()
48+ print(evaluator.results)
49+ """
50+
     task = "Machine Translation"
+
     _datasets = {
         (WMTDataset.News2014, Language.English, Language.German),
         (WMTDataset.News2019, Language.English, Language.German),
@@ -35,6 +69,53 @@ def __init__(self,
                  paper_results: dict = None,
                  model_description: str = None,
                  tokenization: Callable[[str], str] = None):
72+ """
73+ Creates an evaluator for one of the WMT benchmarks.
74+
75+ :param dataset: Which dataset to evaluate on, f.e., WMTDataset.News2014.
76+ :param source_lang: Source language of the documents to translate.
77+ :param target_lang: Target language into which the documents are translated.
78+ :param local_root: Path to the directory where the dataset files are located locally.
79+ Ignored when run on sotabench server.
80+ :param source_dataset_filename: Local filename of the SGML file with the source documents.
81+ If None, the standard WMT filename is used, based on :param:`dataset`,
82+ :param:`source_lang` and :param:`target_lang`.
83+ Ignored when run on sotabench server.
84+ :param target_dataset_filename: Local filename of the SGML file with the reference documents.
85+ If None, the standard WMT filename is used, based on :param:`dataset`,
86+ :param:`source_lang` and :param:`target_lang`.
87+ Ignored when run on sotabench server.
88+ :param model_name: The name of the model from the
89+ paper - if you want to link your build to a model from a
90+ machine learning paper. See the WMT benchmarks page for model names,
91+ (f.e., https://sotabench.com/benchmarks/machine-translation-on-wmt2014-english-german)
92+ on the paper leaderboard or models yet to try tab.
93+ :param paper_arxiv_id: Optional linking to arXiv if you
94+ want to link to papers on the leaderboard; put in the
95+ corresponding paper's arXiv ID, e.g. '1907.06616'.
96+ :param paper_pwc_id: Optional linking to Papers With Code;
97+ put in the corresponding papers with code URL slug, e.g.
98+ 'facebook-fairs-wmt19-news-translation-task'
99+ :param paper_results: If the paper model you are reproducing
100+ does not have model results on sotabench.com, you can specify
101+ the paper results yourself through this argument, where keys
102+ are metric names, values are metric values. e.g:
103+
104+ {'SacreBLEU': 42.7, 'BLEU score': 43.1}.
105+
106+ Ensure that the metric names match those on the sotabench
107+ leaderboard - for WMT benchmarks it should be `SacreBLEU` for de-tokenized
108+ mix-cased BLEU score and `BLEU score` for tokenized BLEU.
109+ :param model_description: Optional model description.
110+ :param tokenization: An optional tokenization function to compute tokenized BLEU score.
111+ It takes a single string - a segment to tokenize, and returns a string with tokens
112+ separated by space, f.e.:
113+
114+ tokenization = lambda seg: seg.replace("'s", " 's").replace("-", " - ")
115+
116+ If None, only de-tokenized SacreBLEU score is reported.
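+
+            For instance, a minimal sketch of wiring in a custom tokenizer (the
+            ``simple_tokenize`` helper below is purely illustrative):
+
+            .. code-block:: python
+
+                def simple_tokenize(segment: str) -> str:
+                    # split possessives and hyphens into separate tokens
+                    return segment.replace("'s", " 's").replace("-", " - ")
+
+                evaluator = WMTEvaluator(
+                    dataset=WMTDataset.News2019,
+                    source_lang=Language.English,
+                    target_lang=Language.German,
+                    tokenization=simple_tokenize,
+                )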
117+ """
118+
38119 super ().__init__ (model_name , paper_arxiv_id , paper_pwc_id , paper_results , model_description )
39120 self .root = change_root_if_server (root = local_root ,
40121 server_root = ".data/nlp/wmt" )
@@ -78,8 +159,27 @@ def _get_dataset_name(self):
         ds_names = {WMTDataset.News2014: "WMT2014", WMTDataset.News2019: "WMT2019"}
         return "{0} {1}-{2}".format(ds_names.get(self.dataset), self.source_lang.fullname, self.target_lang.fullname)
 
-
     def add(self, answers: Dict[str, str]):
163+ """
164+ Updates the evaluator with new results
165+
166+ :param answers: a dict where keys are source segments ids and values are translated segments
167+ (segment id is created by concatenating document id and the original segment id,
168+ separated by `#`.)
169+
170+ Examples:
171+ Update the evaluator with three results:
172+
173+ .. code-block:: python
174+
175+ my_evaluator.add({
176+ 'bbc.381790#1': 'Waliser AMs sorgen sich um "Aussehen wie Muppets"',
177+ 'bbc.381790#2': 'Unter einigen AMs herrscht Bestürzung über einen...',
178+ 'bbc.381790#3': 'Sie ist aufgrund von Plänen entstanden, den Namen...'
179+ })
180+
181+ .. seealso:: `sotabencheval.machine_translation.TranslationMetrics.source_segments`
182+ """
 
         self.metrics.add(answers)
 
@@ -91,9 +191,30 @@ def add(self, answers: Dict[str, str]):
             self.first_batch_processed = True
 
     def reset(self):
194+ """
195+ Removes already added translations
196+
197+ When checking if the model should be rerun on whole dataset it is first run on a smaller subset
198+ and the results are compared with values cached on sotabench server (the check is not performed
199+ when running locally.) Ideally, the smaller subset is just the first batch, so no additional
200+ computation is needed. However, for more complex multistage pipelines it maybe simpler to
201+ run a model twice - on a small dataset and (if necessary) on the full dataset. In that case
202+ :func:`reset` needs to be called before the second run so values from the first run are not reported.
203+
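+        Examples:
+            A sketch of such a two-pass run (``run_pipeline`` stands in for your own
+            multistage inference code and is not part of this library):
+
+            .. code-block:: python
+
+                # first pass: translate only a small subset to check the server cache
+                subset = dict(list(evaluator.metrics.source_segments.items())[:100])
+                evaluator.add(run_pipeline(subset))
+
+                if not evaluator.cache_exists:
+                    # cache miss: drop the partial results and rerun on the full dataset
+                    evaluator.reset()
+                    evaluator.add(run_pipeline(evaluator.metrics.source_segments))
+
+                evaluator.save()
+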
+        .. seealso:: :func:`cache_exists`
+        .. seealso:: :func:`reset_time`
+        """
+
         self.metrics.reset()
 
     def get_results(self):
211+ """
212+ Gets the results for the evaluator. Empty string is assumed for segments for which in translation
213+ was provided.
214+
215+ :return: dict with `SacreBLEU` and `BLEU score`
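+
+        For example (the metric values shown are purely illustrative):
+
+        .. code-block:: python
+
+            results = evaluator.get_results()
+            print(results["SacreBLEU"])  # e.g. 42.7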
216+ """
217+
         if self.cached_results:
             return self.results
         self.results = self.metrics.get_results()