35 changes: 34 additions & 1 deletion lm_eval/base.py
@@ -21,6 +21,8 @@
from lm_eval import utils, metrics
from abc import abstractmethod

from lm_eval.metric_impls import comet as comet_impl


class LM(abc.ABC):
def __init__(self):
@@ -605,7 +607,7 @@ class PromptSourceTask(Task):
"""

CONFIGURED_RANKED_CHOICE_PS_METRICS = set(["Accuracy"])
CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE", "SARI", "COMET"])
SPLIT = None

def __init__(
@@ -750,6 +752,10 @@ def process_results(self, doc, results):
out = {**out, **rouge_scores}
elif metric == "SARI":
out["sari"] = metrics.sari(self.doc_to_rawtext(doc), pred, target)
elif metric == "COMET":
out["comet"] = comet_impl.comet_process_results(
self.doc_to_rawtext(doc), pred, target
)

# TODO: Wrap process_results so that overriding implementations do not
# override the example saving.
@@ -788,6 +794,9 @@ def higher_is_better(self):
out["rougeLsum_fmeasure"] = True
elif metric == "SARI":
out["sari"] = True
if metric == "COMET":
out["comet"] = True

return out

def aggregation(self):
@@ -816,6 +825,9 @@ def aggregation(self):
out["rougeLsum_fmeasure"] = mean
elif metric == "SARI":
out["sari"] = mean
if metric == "COMET":
out["comet"] = comet_impl.comet_aggregation

return out

def fewshot_examples(self, k, rnd):
@@ -952,6 +964,20 @@ def get_logging_info(self):


class TranslationTask(PromptSourceTask):
def __init__(
self,
data_dir=None,
cache_dir=None,
download_mode=None,
prompt=None,
save_examples=True,
):
super().__init__(data_dir, cache_dir, download_mode, prompt, save_examples)

# TODO: Add a check that the language is valid for COMET.
# If not, update the
if "COMET" not in self.prompt.metadata.metrics:
self.prompt.metadata.metrics.append("COMET")

# Language-specific functions.
@classmethod
@@ -1025,6 +1051,13 @@ def process_results(self, doc, results):
rouge_scores = utils.flatten(rouge_scores)
# Merge all the rouge-type scores into the `out` dict.
out = {**out, **rouge_scores}
elif metric == "COMET":
# NOTE: If you use COMET, you must implement doc_to_rawtext,
# which returns the original document text without the prompt
# applied to it.
out["comet"] = comet_impl.comet_process_results(
self.doc_to_rawtext(doc), pred, target
)

# TODO: Wrap process_results so that overriding implementations do not
# override the example saving.
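Taken together, the hooks above follow the harness's usual metric pattern with one twist: for COMET, process_results emits a raw {"src", "mt", "ref"} row instead of a score, and the aggregation callable runs the model once over all rows. A simplified, hypothetical sketch of what the evaluator loop does with these hooks:

    # `predictions` pairs each doc with its generated results (hypothetical name).
    rows = [task.process_results(doc, results)["comet"] for doc, results in predictions]
    comet_fn = task.aggregation()["comet"]  # comet_impl.comet_aggregation
    print(comet_fn(rows))                   # {"comet": <system-level score>}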
60 changes: 60 additions & 0 deletions lm_eval/metric_impls/comet.py
@@ -0,0 +1,60 @@
_CITATION = """
We use the implementation from https://github.com/Unbabel/COMET.

@inproceedings{stewart-etal-2020-comet,
title = "{COMET} - Deploying a New State-of-the-art {MT} Evaluation Metric in Production",
author = "Stewart, Craig and
Rei, Ricardo and
Farinha, Catarina and
Lavie, Alon",
booktitle = "Proceedings of the 14th Conference of the Association for Machine Translation in the Americas (Volume 2: User Track)",
month = oct,
year = "2020",
address = "Virtual",
publisher = "Association for Machine Translation in the Americas",
url = "https://aclanthology.org/2020.amta-user.4",
pages = "78--109",
}
"""

"""Comet is built on top of XLM-R which cover the following languages:

Afrikaans, Albanian, Amharic, Arabic, Armenian, Assamese, Azerbaijani, Basque, Belarusian, Bengali, Bengali Romanized, Bosnian, Breton, Bulgarian, Burmese, Catalan, Chinese (Simplified), Chinese (Traditional), Croatian, Czech, Danish, Dutch, English, Esperanto, Estonian, Filipino, Finnish, French, Galician, Georgian, German, Greek, Gujarati, Hausa, Hebrew, Hindi, Hindi Romanized, Hungarian, Icelandic, Indonesian, Irish, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Kurdish (Kurmanji), Kyrgyz, Lao, Latin, Latvian, Lithuanian, Macedonian, Malagasy, Malay, Malayalam, Marathi, Mongolian, Nepali, Norwegian, Oriya, Oromo, Pashto, Persian, Polish, Portuguese, Punjabi, Romanian, Russian, Sanskrit, Scottish Gaelic, Serbian, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tamil, Tamil Romanized, Telugu, Telugu Romanized, Thai, Turkish, Ukrainian, Urdu, Urdu Romanized, Uyghur, Uzbek, Vietnamese, Welsh, Western Frisian, Xhosa, Yiddish.
"""

"""NOTE: Promptsource does not support COMET, so you have to add COMET to the list
of metrics when your task is instantiated in order to use it without over-riding process_results.
"""

import torch

from comet import download_model, load_from_checkpoint


def comet_process_results(src, pred, ref):
"""Per instance."""
if isinstance(ref, list):
assert len(ref) == 1, "Comet expects a single reference."
# https://github.com/Unbabel/COMET/issues/20
# If we want to add support for this, we need to average across multiple instances.
ref = ref[0]
return {"src": src, "mt": pred, "ref": ref}


def comet_aggregation(data):
# While we could predict the COMET score for each row individually, we instead
# score everything once as a batch: COMET requires running a model, so it makes
# more sense to batch the forward passes.
# Download sizes: 1.79G (wmt20-comet-da) + 2.09G (xlm-roberta-large).
model_path = download_model("wmt20-comet-da")
model = load_from_checkpoint(model_path)
gpus = 1 if torch.cuda.is_available() else 0
_, sys_score = model.predict(data, batch_size=8, gpus=gpus)
return {"comet": sys_score}


def comet_higher_is_better():
return {"comet": True}
7 changes: 5 additions & 2 deletions lm_eval/tasks/wmt.py
@@ -34,9 +34,12 @@ def test_docs(self):
def max_generation_length(self) -> typing.Optional[int]:
return 64

def doc_to_rawtext(self, doc):
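# A wmt doc looks like {"translation": {"en": ..., "fr": ...}} (per the HF
# wmt datasets); COMET needs the untemplated source-side text as `src`,
# not the prompted input.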
source = self._get_src_ref_codes(self.prompt.name)[0]
return doc["translation"][source]


# WMT 2014
class WMT14Base(WMTBase):
DATASET_PATH = "wmt14"

@@ -77,5 +80,5 @@ def create_year_tasks(year_classes) -> typing.Dict[str, WMTBase]:
for task_class in year_classes:
benchmark = task_class.DATASET_PATH
lang_pair = task_class.DATASET_NAME.replace("-", "_")
tasks[f"{benchmark}_{lang_pair}"] = task_class
return tasks
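A quick usage sketch of create_year_tasks (WMT14EnFr here is a hypothetical per-pair subclass; the real ones set DATASET_NAME the same way):

    class WMT14EnFr(WMT14Base):
        DATASET_NAME = "en-fr"

    tasks = create_year_tasks([WMT14EnFr])
    assert "wmt14_en_fr" in tasks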
3 changes: 2 additions & 1 deletion setup.py
@@ -22,6 +22,7 @@
install_requires=[
"promptsource @ git+https://github.com/bigscience-workshop/promptsource@eval-hackathon",
"codecarbon",
"unbabel-comet",
"wrapt",
"nltk",
"jinja2",
@@ -51,5 +52,5 @@
dependency_links=[
"https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
],
extras_require={"dev": ["pytest", "black"]},
)