From 9635de6d7db37a2fcaee17f36dcc397d85f667cf Mon Sep 17 00:00:00 2001 From: cjlovering Date: Tue, 10 May 2022 14:41:54 -0400 Subject: [PATCH 1/2] Initial commit for new metric: `Comet` * Added dependency * Added wiring for using it by default in a translation task * (We should discuss the above -- i.e. is it desired?) * Uses the same framework as SARI -- exepcts a `doc_to_rawtext` to be implemented --- lm_eval/base.py | 35 ++++++++++++++++++++- lm_eval/metric_impls/comet.py | 59 +++++++++++++++++++++++++++++++++++ setup.py | 3 +- 3 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 lm_eval/metric_impls/comet.py diff --git a/lm_eval/base.py b/lm_eval/base.py index 66af7a94f47..ddcc320c3dc 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -21,6 +21,8 @@ from lm_eval import utils, metrics from abc import abstractmethod +from metric_impls import comet as comet_impl + class LM(abc.ABC): def __init__(self): @@ -605,7 +607,7 @@ class PromptSourceTask(Task): """ CONFIGURED_RANKED_CHOICE_PS_METRICS = set(["Accuracy"]) - CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE", "SARI"]) + CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE", "SARI", "COMET"]) SPLIT = None def __init__( @@ -750,6 +752,10 @@ def process_results(self, doc, results): out = {**out, **rouge_scores} elif metric == "SARI": out["sari"] = metrics.sari(self.doc_to_rawtext(doc), pred, target) + elif metric == "COMET": + out["comet"] = comet_impl.comet_process_results( + self.doc_to_rawtext(doc), pred, target + ) # TODO: Wrap process results s.t. override impl do not # override the save examples. @@ -788,6 +794,9 @@ def higher_is_better(self): out["rougeLsum_fmeasure"] = True elif metric == "SARI": out["sari"] = True + if metric == "COMET": + out["comet"] = True + return out def aggregation(self): @@ -816,6 +825,9 @@ def aggregation(self): out["rougeLsum_fmeasure"] = mean elif metric == "SARI": out["sari"] = mean + if metric == "COMET": + out["comet"] = comet_impl.comet_aggregation + return out def fewshot_examples(self, k, rnd): @@ -952,6 +964,20 @@ def get_logging_info(self): class TranslationTask(PromptSourceTask): + def __init__( + self, + data_dir=None, + cache_dir=None, + download_mode=None, + prompt=None, + save_examples=True, + ): + super().__init__(data_dir, cache_dir, download_mode, prompt, save_examples) + + # TODO: Add check that language is valid for COMET + # IF NOT, update the + if "COMET" not in self.prompt.metadata.metrics: + self.prompt.metadata.metrics.append("COMET") # Language specific functions. @classmethod @@ -1025,6 +1051,13 @@ def process_results(self, doc, results): rouge_scores = utils.flatten(rouge_scores) # Merge all the rouge-type scores into the `out` dict. out = {**out, **rouge_scores} + elif metric == "COMET": + # NOTE: If you use COMET, you must implement doc_to_rawtext + # which here serves as the original doc without the prompt added + # to it. + out["comet"] = comet_impl.comet_process_results( + self.doc_to_rawtext(doc), pred, target + ) # TODO: Wrap process results s.t. override impl do not # override the save examples. diff --git a/lm_eval/metric_impls/comet.py b/lm_eval/metric_impls/comet.py new file mode 100644 index 00000000000..c55a3f1f7dd --- /dev/null +++ b/lm_eval/metric_impls/comet.py @@ -0,0 +1,59 @@ +_CITATION = """ +We use the implementation from https://github.com/Unbabel/COMET. + +@inproceedings{stewart-etal-2020-comet, + title = "{COMET} - Deploying a New State-of-the-art {MT} Evaluation Metric in Production", + author = "Stewart, Craig and + Rei, Ricardo and + Farinha, Catarina and + Lavie, Alon", + booktitle = "Proceedings of the 14th Conference of the Association for Machine Translation in the Americas (Volume 2: User Track)", + month = oct, + year = "2020", + address = "Virtual", + publisher = "Association for Machine Translation in the Americas", + url = "https://aclanthology.org/2020.amta-user.4", + pages = "78--109", +} +""" + +"""Comet is built on top of XLM-R which cover the following languages: + +Afrikaans, Albanian, Amharic, Arabic, Armenian, Assamese, Azerbaijani, Basque, Belarusian, Bengali, Bengali Romanized, Bosnian, Breton, Bulgarian, Burmese, Burmese, Catalan, Chinese (Simplified), Chinese (Traditional), Croatian, Czech, Danish, Dutch, English, Esperanto, Estonian, Filipino, Finnish, French, Galician, Georgian, German, Greek, Gujarati, Hausa, Hebrew, Hindi, Hindi Romanized, Hungarian, Icelandic, Indonesian, Irish, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Kurdish (Kurmanji), Kyrgyz, Lao, Latin, Latvian, Lithuanian, Macedonian, Malagasy, Malay, Malayalam, Marathi, Mongolian, Nepali, Norwegian, Oriya, Oromo, Pashto, Persian, Polish, Portuguese, Punjabi, Romanian, Russian, Sanskri, Scottish, Gaelic, Serbian, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tamil, Tamil Romanized, Telugu, Telugu Romanized, Thai, Turkish, Ukrainian, Urdu, Urdu Romanized, Uyghur, Uzbek, Vietnamese, Welsh, Western, Frisian, Xhosa, Yiddish. +""" + +"""NOTE: Promptsource does not support COMET, so you have to add COMET to the list +of metrics when your task is instantiated in order to use it without over-riding process_results. +""" + +import torch + +from comet import download_model, load_from_checkpoint + + +def comet_process_results(src, pred, ref): + """Per instance.""" + if isinstance(ref, list): + assert len(ref) == 1, "Comet expects a single reference." + # https://github.com/Unbabel/COMET/issues/20 + # If we want to add support for this, we need to average across multiple instances. + ref = ref[0] + return {"src": src, "mt": pred, "ref": ref} + + +def comet_aggregation(data): + # While we could be predict the comet scores for each row, instead we do it once + # as a batch because it requires using a model, and it makes more sense to batch + # the operations. + model_path = download_model("wmt20-comet-da") + model = load_from_checkpoint(model_path) + if torch.cuda.is_available(): + gpus = 1 + else: + gpus = 0 + _, sys_score = model.predict(data, batch_size=8, gpus=gpus) + return {"comet": sys_score} + + +def comet_higher_is_better(): + return {"comet": True} diff --git a/setup.py b/setup.py index 6babe063704..12cf08fad69 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ install_requires=[ "promptsource @ git+https://github.com/bigscience-workshop/promptsource@eval-hackathon", "codecarbon", + "unbabel-comet", "wrapt", "nltk", "jinja2", @@ -51,5 +52,5 @@ dependency_links=[ "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt", ], - extras_require={'dev': [ 'pytest', 'black' ]} + extras_require={"dev": ["pytest", "black"]}, ) From 147f87eea7d04da2a1eaaf6462f31e99daa66955 Mon Sep 17 00:00:00 2001 From: cjlovering Date: Wed, 11 May 2022 19:03:25 -0400 Subject: [PATCH 2/2] Testing out comet --- lm_eval/base.py | 2 +- lm_eval/metric_impls/comet.py | 1 + lm_eval/tasks/wmt.py | 7 +++++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/lm_eval/base.py b/lm_eval/base.py index 19104b7e75f..73aaccdec39 100644 --- a/lm_eval/base.py +++ b/lm_eval/base.py @@ -21,7 +21,7 @@ from lm_eval import utils, metrics from abc import abstractmethod -from metric_impls import comet as comet_impl +from lm_eval.metric_impls import comet as comet_impl class LM(abc.ABC): diff --git a/lm_eval/metric_impls/comet.py b/lm_eval/metric_impls/comet.py index c55a3f1f7dd..ad15a33a798 100644 --- a/lm_eval/metric_impls/comet.py +++ b/lm_eval/metric_impls/comet.py @@ -45,6 +45,7 @@ def comet_aggregation(data): # While we could be predict the comet scores for each row, instead we do it once # as a batch because it requires using a model, and it makes more sense to batch # the operations. + # 1.79G (wmt20-comet-da) + 2.09G (xlm-roberta-large) model_path = download_model("wmt20-comet-da") model = load_from_checkpoint(model_path) if torch.cuda.is_available(): diff --git a/lm_eval/tasks/wmt.py b/lm_eval/tasks/wmt.py index 4424b6278bb..2b0a31ad894 100644 --- a/lm_eval/tasks/wmt.py +++ b/lm_eval/tasks/wmt.py @@ -34,9 +34,12 @@ def test_docs(self): def max_generation_length(self) -> typing.Optional[int]: return 64 + def doc_to_rawtext(self, doc): + source = self._get_src_ref_codes(self.prompt.name)[0] + return doc["translation"][source] -# WMT 2014 +# WMT 2014 class WMT14Base(WMTBase): DATASET_PATH = "wmt14" @@ -77,5 +80,5 @@ def create_year_tasks(year_classes) -> typing.Dict[str, WMTBase]: for task_class in year_classes: benchmark = task_class.DATASET_PATH lang_pair = task_class.DATASET_NAME.replace("-", "_") - tasks[f'{benchmark}_{lang_pair}'] = task_class + tasks[f"{benchmark}_{lang_pair}"] = task_class return tasks