Skip to content

Commit 0523bfe

Browse files
committed
Add support for Tokenized BLEU
* change "BLEU score" to "SacreBLEU" * accept tokenization function to compute tokenized BLEU * return tokenized BLEU as "BLEU score" for direct comparison with paperswithcode results * update docs
1 parent 13b526c commit 0523bfe

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

docs/docs/wmt.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,24 @@ evaluator = WMTEvaluator(
7676

7777
The above will directly compare with the result of the paper when run on the server.
7878

79+
By default the evaluator computes a detokenized mixed-case SacreBLEU score.
80+
To get a tokenized BLEU score as well, during construction of the evaluator set
81+
a `tokenization: Callable[[str], str]` parameter to a function that tokenizes
82+
an input segment and returns the segment with tokens separated by spaces, e.g.:
83+
84+
``` python
85+
def get_tokenization():
86+
mt = sacremoses.MosesTokenizer()
87+
def tokenize(sentence):
88+
return mt.tokenize(sentence, return_str=True)
89+
return tokenize
90+
91+
evaluator = WMTEvaluator(
92+
...,
93+
tokenization=get_tokenization()
94+
)
95+
```
96+
7997
Instead of parsing the dataset files by yourself you can access raw segments as strings:
8098

8199
``` python

sotabencheval/machine_translation/metrics.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22
from bs4 import BeautifulSoup
33
from pathlib import Path
4-
from typing import Dict, List
4+
from typing import Dict, List, Callable
55
from collections import OrderedDict
66
from sacrebleu import corpus_bleu
77

@@ -10,12 +10,16 @@
1010

1111

1212
class TranslationMetrics:
13-
def __init__(self, source_dataset_path: Path, target_dataset_path):
13+
def __init__(self,
14+
source_dataset_path: Path,
15+
target_dataset_path: Path,
16+
tokenization: Callable[[str], str] = None):
1417
self._src_dataset_path = source_dataset_path
1518
self._dst_dataset_path = target_dataset_path
1619
self.answers = {}
1720
self.source_documents, self.source_segments = self._load_dataset(self._src_dataset_path)
1821
self._target_documents, self._target_segments = self._load_dataset(self._dst_dataset_path)
22+
self._tokenization = tokenization
1923
self._results = None
2024

2125
def _load_dataset(self, dataset_path):
@@ -41,12 +45,16 @@ def evaluate(self, ignore_missing=False):
4145
target_segments = {sid: text for sid, text in self._target_segments.items() if sid in keep}
4246
else:
4347
target_segments = self._target_segments
44-
references = [[target for target in target_segments.values()]]
4548
answers = [self.answers.get(sid, "") for sid in target_segments]
46-
bleu = corpus_bleu(answers, references)
47-
self._results = {
48-
'BLEU score': bleu.score
49-
}
49+
references = [target for target in target_segments.values()]
50+
bleu = corpus_bleu(answers, [references])
51+
self._results = {'SacreBLEU': bleu.score}
52+
53+
if self._tokenization is not None:
54+
tok_answers = [self._tokenization(answer) for answer in answers]
55+
tok_references = [self._tokenization(target) for target in references]
56+
tok_bleu = corpus_bleu(tok_answers, [tok_references], tokenize='none', force=True)
57+
self._results['BLEU score'] = tok_bleu.score
5058

5159
@property
5260
def has_data(self):

sotabencheval/machine_translation/wmt.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sotabencheval.machine_translation.languages import Language
44
from sotabencheval.machine_translation.metrics import TranslationMetrics
55
from sotabencheval.utils import get_max_memory_allocated
6-
from typing import Dict
6+
from typing import Dict, Callable
77
from pathlib import Path
88
from enum import Enum
99
import time
@@ -33,7 +33,8 @@ def __init__(self,
3333
paper_arxiv_id: str = None,
3434
paper_pwc_id: str = None,
3535
paper_results: dict = None,
36-
model_description=None):
36+
model_description: str = None,
37+
tokenization: Callable[[str], str] = None):
3738
super().__init__(model_name, paper_arxiv_id, paper_pwc_id, paper_results, model_description)
3839
self.root = change_root_if_server(root=local_root,
3940
server_root=".data/nlp/wmt")
@@ -51,7 +52,7 @@ def __init__(self,
5152
self.source_dataset_path = Path(self.root) / source_dataset_filename
5253
self.target_dataset_path = Path(self.root) / target_dataset_filename
5354

54-
self.metrics = TranslationMetrics(self.source_dataset_path, self.target_dataset_path)
55+
self.metrics = TranslationMetrics(self.source_dataset_path, self.target_dataset_path, tokenization)
5556

5657
def _get_source_dataset_filename(self):
5758
if self.dataset == WMTDataset.News2014:

0 commit comments

Comments
 (0)