diff --git a/lm_eval/base.py b/lm_eval/base.py
index 4caf8a12df6..73aaccdec39 100644
--- a/lm_eval/base.py
+++ b/lm_eval/base.py
@@ -21,6 +21,8 @@
 from lm_eval import utils, metrics
 from abc import abstractmethod
 
+from lm_eval.metric_impls import comet as comet_impl
+
 
 class LM(abc.ABC):
     def __init__(self):
@@ -605,7 +607,7 @@ class PromptSourceTask(Task):
     """
 
     CONFIGURED_RANKED_CHOICE_PS_METRICS = set(["Accuracy"])
-    CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE", "SARI"])
+    CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE", "SARI", "COMET"])
     SPLIT = None
 
     def __init__(
@@ -750,6 +752,10 @@ def process_results(self, doc, results):
                 out = {**out, **rouge_scores}
             elif metric == "SARI":
                 out["sari"] = metrics.sari(self.doc_to_rawtext(doc), pred, target)
+            elif metric == "COMET":
+                out["comet"] = comet_impl.comet_process_results(
+                    self.doc_to_rawtext(doc), pred, target
+                )
 
         # TODO: Wrap process results s.t. override impl do not
         # override the save examples.
@@ -788,6 +794,9 @@ def higher_is_better(self):
                 out["rougeLsum_fmeasure"] = True
             elif metric == "SARI":
                 out["sari"] = True
+            elif metric == "COMET":
+                out["comet"] = True
+
         return out
 
     def aggregation(self):
@@ -816,6 +825,9 @@ def aggregation(self):
                 out["rougeLsum_fmeasure"] = mean
             elif metric == "SARI":
                 out["sari"] = mean
+            elif metric == "COMET":
+                out["comet"] = comet_impl.comet_aggregation
+
         return out
 
     def fewshot_examples(self, k, rnd):
@@ -952,6 +964,20 @@ def get_logging_info(self):
 
 
 class TranslationTask(PromptSourceTask):
+    def __init__(
+        self,
+        data_dir=None,
+        cache_dir=None,
+        download_mode=None,
+        prompt=None,
+        save_examples=True,
+    ):
+        super().__init__(data_dir, cache_dir, download_mode, prompt, save_examples)
+
+        # TODO: Add check that language is valid for COMET
+        # IF NOT, update the
+        if "COMET" not in self.prompt.metadata.metrics:
+            self.prompt.metadata.metrics.append("COMET")
 
     # Language specific functions.
     @classmethod
@@ -1025,6 +1051,13 @@ def process_results(self, doc, results):
                 rouge_scores = utils.flatten(rouge_scores)
                 # Merge all the rouge-type scores into the `out` dict.
                 out = {**out, **rouge_scores}
+            elif metric == "COMET":
+                # NOTE: If you use COMET, you must implement doc_to_rawtext,
+                # which here serves as the original doc without the prompt added
+                # to it.
+                out["comet"] = comet_impl.comet_process_results(
+                    self.doc_to_rawtext(doc), pred, target
+                )
 
         # TODO: Wrap process results s.t. override impl do not
         # override the save examples.
diff --git a/lm_eval/metric_impls/comet.py b/lm_eval/metric_impls/comet.py
new file mode 100644
index 00000000000..ad15a33a798
--- /dev/null
+++ b/lm_eval/metric_impls/comet.py
@@ -0,0 +1,60 @@
+_CITATION = """
+We use the implementation from https://github.com/Unbabel/COMET.
+
+@inproceedings{stewart-etal-2020-comet,
+    title = "{COMET} - Deploying a New State-of-the-art {MT} Evaluation Metric in Production",
+    author = "Stewart, Craig and
+      Rei, Ricardo and
+      Farinha, Catarina and
+      Lavie, Alon",
+    booktitle = "Proceedings of the 14th Conference of the Association for Machine Translation in the Americas (Volume 2: User Track)",
+    month = oct,
+    year = "2020",
+    address = "Virtual",
+    publisher = "Association for Machine Translation in the Americas",
+    url = "https://aclanthology.org/2020.amta-user.4",
+    pages = "78--109",
+}
+"""
+
+"""COMET is built on top of XLM-R, which covers the following languages:
+
+Afrikaans, Albanian, Amharic, Arabic, Armenian, Assamese, Azerbaijani, Basque, Belarusian, Bengali, Bengali Romanized, Bosnian, Breton, Bulgarian, Burmese, Catalan, Chinese (Simplified), Chinese (Traditional), Croatian, Czech, Danish, Dutch, English, Esperanto, Estonian, Filipino, Finnish, French, Galician, Georgian, German, Greek, Gujarati, Hausa, Hebrew, Hindi, Hindi Romanized, Hungarian, Icelandic, Indonesian, Irish, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Kurdish (Kurmanji), Kyrgyz, Lao, Latin, Latvian, Lithuanian, Macedonian, Malagasy, Malay, Malayalam, Marathi, Mongolian, Nepali, Norwegian, Oriya, Oromo, Pashto, Persian, Polish, Portuguese, Punjabi, Romanian, Russian, Sanskrit, Scottish Gaelic, Serbian, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tamil, Tamil Romanized, Telugu, Telugu Romanized, Thai, Turkish, Ukrainian, Urdu, Urdu Romanized, Uyghur, Uzbek, Vietnamese, Welsh, Western Frisian, Xhosa, Yiddish.
+"""
+
+"""NOTE: Promptsource does not support COMET, so you have to add COMET to the list
+of metrics when your task is instantiated in order to use it without overriding process_results.
+"""
+
+import torch
+
+from comet import download_model, load_from_checkpoint
+
+
+def comet_process_results(src, pred, ref):
+    """Per instance."""
+    if isinstance(ref, list):
+        assert len(ref) == 1, "Comet expects a single reference."
+        # https://github.com/Unbabel/COMET/issues/20
+        # If we want to add support for this, we need to average across the multiple references.
+        ref = ref[0]
+    return {"src": src, "mt": pred, "ref": ref}
+
+
+def comet_aggregation(data):
+    # While we could predict the COMET score for each row individually, we instead do it once
+    # as a batch, because it requires running a model and it makes more sense to batch
+    # the operations.
+    # 1.79G (wmt20-comet-da) + 2.09G (xlm-roberta-large)
+    model_path = download_model("wmt20-comet-da")
+    model = load_from_checkpoint(model_path)
+    if torch.cuda.is_available():
+        gpus = 1
+    else:
+        gpus = 0
+    _, sys_score = model.predict(data, batch_size=8, gpus=gpus)
+    return {"comet": sys_score}
+
+
+def comet_higher_is_better():
+    return {"comet": True}
diff --git a/lm_eval/tasks/wmt.py b/lm_eval/tasks/wmt.py
index 4424b6278bb..2b0a31ad894 100644
--- a/lm_eval/tasks/wmt.py
+++ b/lm_eval/tasks/wmt.py
@@ -34,9 +34,12 @@ def test_docs(self):
     def max_generation_length(self) -> typing.Optional[int]:
         return 64
 
+    def doc_to_rawtext(self, doc):
+        source = self._get_src_ref_codes(self.prompt.name)[0]
+        return doc["translation"][source]
 
-# WMT 2014
+# WMT 2014
 class WMT14Base(WMTBase):
     DATASET_PATH = "wmt14"
@@ -77,5 +80,5 @@ def create_year_tasks(year_classes) -> typing.Dict[str, WMTBase]:
     for task_class in year_classes:
         benchmark = task_class.DATASET_PATH
         lang_pair = task_class.DATASET_NAME.replace("-", "_")
-        tasks[f'{benchmark}_{lang_pair}'] = task_class
+        tasks[f"{benchmark}_{lang_pair}"] = task_class
     return tasks
diff --git a/setup.py b/setup.py
index 6babe063704..12cf08fad69 100644
--- a/setup.py
+++ b/setup.py
@@ -22,6 +22,7 @@
     install_requires=[
         "promptsource @ git+https://github.com/bigscience-workshop/promptsource@eval-hackathon",
        "codecarbon",
+        "unbabel-comet",
        "wrapt",
        "nltk",
        "jinja2",
@@ -51,5 +52,5 @@
    dependency_links=[
        "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
    ],
-    extras_require={'dev': [ 'pytest', 'black' ]}
+    extras_require={"dev": ["pytest", "black"]},
 )
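
Usage sketch, not part of the patch: the snippet below shows roughly how the two new helpers fit together, mirroring what PromptSourceTask.process_results and aggregation do in the base.py changes above. It assumes unbabel-comet is installed (it is added to setup.py in this diff) and that the wmt20-comet-da checkpoint, roughly 4 GB of weights per the comment in comet_aggregation, can be downloaded; the sentences are made-up placeholders.

    from lm_eval.metric_impls.comet import comet_aggregation, comet_process_results

    # One dict per evaluated example: raw source text, model prediction, single reference.
    rows = [
        comet_process_results(
            src="Der Hund bellt.", pred="The dog barks.", ref=["The dog is barking."]
        ),
        comet_process_results(
            src="Ich mag Kaffee.", pred="I like coffee.", ref=["I like coffee."]
        ),
    ]

    # Aggregation downloads/loads the COMET model once, scores the whole batch,
    # and returns the corpus-level system score.
    print(comet_aggregation(rows))  # {"comet": <system-level score>}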