diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py index 3e0b45121..0112ef4f1 100644 --- a/src/lighteval/metrics/dynamic_metrics.py +++ b/src/lighteval/metrics/dynamic_metrics.py @@ -25,7 +25,9 @@ import numpy as np +from lighteval.metrics.metrics_corpus import CorpusLevelTranslationMetric from lighteval.metrics.metrics_sample import ( + BLEU, ExactMatches, F1_score, LoglikelihoodAcc, @@ -38,6 +40,7 @@ LogProbTokenNorm, get_multilingual_normalizer, ) +from lighteval.metrics.sample_preparator import GenerativePreparator from lighteval.metrics.utils.extractive_match_utils import ( # noqa: F401 ExprExtractionConfig, ExtractionTarget, @@ -47,7 +50,7 @@ get_extraction_regexes, ) from lighteval.metrics.utils.math_comparison import compare_gold_target -from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric +from lighteval.metrics.utils.metric_utils import CorpusLevelMetric, MetricCategory, MetricUseCase, SampleLevelMetric from lighteval.tasks.requests import Doc from lighteval.utils.language import Language from lighteval.utils.timeout import timeout @@ -122,7 +125,10 @@ def probability_metric( def multilingual_quasi_f1_score_metric( - language: Language, aggregation_function: Callable[[list[float]], float] = max + language: Language, + aggregation_function: Callable[[list[float]], float] = max, + normalize_gold: Callable[[str], str] | None = None, + normalize_pred: Callable[[str], str] | None = None, ) -> SampleLevelMetric: """ Creates a language-aware F1 score metric, which returns the F1 score. @@ -130,18 +136,23 @@ def multilingual_quasi_f1_score_metric( Args: language: The language of the samples. aggregation_function: Aggregation samples to use when multiple golds are present. + normalize_gold: Normalization function for gold answers. + normalize_pred: Normalization function for predictions. Returns: F1 score metric. """ metric_name = f"f1_{language.value}" - multilang_normalizer = get_multilingual_normalizer(language) + base_normalizer = get_multilingual_normalizer(language) + gold_normalizer = (lambda x: base_normalizer(normalize_gold(x))) if normalize_gold is not None else base_normalizer + pred_normalizer = (lambda x: base_normalizer(normalize_pred(x))) if normalize_pred is not None else base_normalizer + return SampleLevelMetric( metric_name=metric_name, sample_level_fn=F1_score( - normalize_gold=multilang_normalizer, - normalize_pred=multilang_normalizer, + normalize_gold=gold_normalizer, + normalize_pred=pred_normalizer, aggregation_function=aggregation_function, ).compute, category=MetricCategory.GENERATIVE, @@ -155,6 +166,8 @@ def multilingual_quasi_exact_match_metric( language: Language, match_type: Literal["prefix", "suffix", "full"] = "full", aggregation_function: Callable[[list[float]], float] = max, + normalize_gold: Callable[[str], str] | None = None, + normalize_pred: Callable[[str], str] | None = None, ) -> SampleLevelMetric: """ Creates a language-aware exact match metric, which returns the exact match score @@ -165,16 +178,21 @@ def multilingual_quasi_exact_match_metric( - "suffix": Suffixes must match - "full": Full strings must match aggregation_function: Aggregation samples to use when multiple golds are present. + normalize_gold: Normalization function for gold answers. + normalize_pred: Normalization function for predictions. Returns: Exact match metric. 
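    Example:
        A minimal sketch of how the new hooks compose: strip_answer below is a
        hypothetical caller-supplied extractor (e.g. for chain-of-thought outputs),
        not a helper from this module; it runs first, and the language-aware
        normalizer is then applied to its output before the exact-match comparison.

            strip_answer = lambda text: text.rsplit("Answer:", 1)[-1]
            metric = multilingual_quasi_exact_match_metric(
                Language.SWAHILI, "prefix", normalize_pred=strip_answer
            )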
""" metric_name = f"exact_match_{language.value}_{match_type}" - multilang_normalizer = get_multilingual_normalizer(language) + base_normalizer = get_multilingual_normalizer(language) + gold_normalizer = (lambda x: base_normalizer(normalize_gold(x))) if normalize_gold is not None else base_normalizer + pred_normalizer = (lambda x: base_normalizer(normalize_pred(x))) if normalize_pred is not None else base_normalizer + return SampleLevelMetric( metric_name=metric_name, sample_level_fn=ExactMatches( - normalize_gold=multilang_normalizer, - normalize_pred=multilang_normalizer, + normalize_gold=gold_normalizer, + normalize_pred=pred_normalizer, aggregation_function=aggregation_function, type_exact_match=match_type, ).compute, @@ -185,6 +203,38 @@ def multilingual_quasi_exact_match_metric( ) +def translation_metric( + metric_name: Literal["bleu", "bleu_1", "bleu_4", "chrf", "chrf++"], + normalize_pred: Callable[[str], str] | None = None, + normalize_gold: Callable[[str], str] | None = None, +) -> CorpusLevelMetric | SampleLevelMetric: + """ + Creates a translation metric, which returns the translation score. + """ + if metric_name.startswith("bleu_"): + return SampleLevelMetric( + metric_name=metric_name, + sample_level_fn=BLEU( + n_gram=int(metric_name.split("_")[1]), normalize_pred=normalize_pred, normalize_gold=normalize_gold + ).compute, + category=MetricCategory.GENERATIVE, + use_case=MetricUseCase.TRANSLATION, + corpus_level_fn=np.mean, + higher_is_better=True, + ) + else: + return CorpusLevelMetric( + metric_name=metric_name, + sample_level_fn=GenerativePreparator().prepare, + category=MetricCategory.GENERATIVE, + use_case=MetricUseCase.TRANSLATION, + corpus_level_fn=CorpusLevelTranslationMetric( + metric_name, normalize_pred=normalize_pred, normalize_gold=normalize_gold + ).compute, # type: ignore + higher_is_better=True, + ) + + def multilingual_extractive_match_metric( language: Language = Language.ENGLISH, gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),), diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py index 030725a53..b0d488f8f 100644 --- a/src/lighteval/metrics/metrics_corpus.py +++ b/src/lighteval/metrics/metrics_corpus.py @@ -27,7 +27,7 @@ import logging import math -from typing import Literal +from typing import Callable, Literal import numpy as np import sacrebleu @@ -91,7 +91,13 @@ def compute(self, items: list[LogprobCorpusMetricInput]): class CorpusLevelTranslationMetric: - def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""): + def __init__( + self, + metric_type: Literal["bleu", "chrf", "chrf++", "ter"], + lang: Literal["zh", "ja", "ko", ""] = "", + normalize_pred: Callable[[str], str] | None = None, + normalize_gold: Callable[[str], str] | None = None, + ): """Stores the relevant parameters for a corpus level translation metric. 
Args: @@ -99,6 +105,8 @@ def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""): """ self.metric_type = metric_type self.lang = lang + self.normalize_pred = normalize_pred if normalize_pred is not None else lambda x: x + self.normalize_gold = normalize_gold if normalize_gold is not None else lambda x: x def get_metric(self): if self.metric_type == "bleu": @@ -115,7 +123,7 @@ def get_metric(self): def compute(self, items: list[GenerativeCorpusMetricInput]) -> float: """Computes the metric score over all the corpus generated items, by using the sacrebleu implementation.""" metric = self.get_metric() - golds = [i.golds for i in items] + golds = [[self.normalize_gold(gold) for gold in i.golds] for i in items] preds = [] for i in items: pred = as_list(i.preds) @@ -123,7 +131,7 @@ def compute(self, items: list[GenerativeCorpusMetricInput]) -> float: logger.info( f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{metric.__name__})." ) - preds.append(pred[0]) + preds.append(self.normalize_pred(pred[0])) return float(metric.corpus_score(hypotheses=preds, references=golds).score) diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 984e2607a..795aa0918 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -744,7 +744,12 @@ def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float: class BLEU: - def __init__(self, n_gram: int): + def __init__( + self, + n_gram: int, + normalize_pred: Callable[[str], str] | None = None, + normalize_gold: Callable[[str], str] | None = None, + ): """BLEU scorer class. Relies on `nltk`'s sentencebleu for scoring. TODO: Will have to move this to sacrebleu. @@ -752,6 +757,8 @@ def __init__(self, n_gram: int): n_gram (int): Number of n_grams to use for scoring. """ self.n_gram = n_gram + self.normalize_pred = normalize_pred + self.normalize_gold = normalize_gold def compute(self, golds: list[str], predictions: list[str], **kwargs): """Computes the sentence level BLEU between the golds and each prediction, then takes the average. @@ -763,6 +770,10 @@ def compute(self, golds: list[str], predictions: list[str], **kwargs): Returns: float: Score over the current sample's items. 
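        Example:
            An illustrative call with made-up strings (not dataset samples); passing
            str.lower for both hooks removes the case difference between gold and
            prediction before the per-sentence BLEU scores are averaged:

                BLEU(n_gram=1, normalize_gold=str.lower, normalize_pred=str.lower).compute(
                    golds=["The cat sat on the mat."],
                    predictions=["the cat sat on the mat."],
                )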
""" + if self.normalize_pred: + predictions = [self.normalize_pred(p) for p in predictions] + if self.normalize_gold: + golds = [self.normalize_gold(g) for g in golds] return np.mean([self._bleu_score(golds, p) for p in predictions]) def _bleu_score(self, gold: list[str], pred: str) -> float: diff --git a/src/lighteval/metrics/utils/extractive_match_utils.py b/src/lighteval/metrics/utils/extractive_match_utils.py index b8c529d05..799b580de 100644 --- a/src/lighteval/metrics/utils/extractive_match_utils.py +++ b/src/lighteval/metrics/utils/extractive_match_utils.py @@ -89,6 +89,7 @@ class IndicesExtractionConfig: """ prefix_for_extraction: ChoicePrefix + bb_match_priority: int = -1 try_extract_without_anchor: bool = True @@ -340,6 +341,9 @@ def lazy_indices_regex( ] ) + if indices_config.bb_match_priority >= 0: + regexes.append((rf"\s*{indice_str_re}\s*", indices_config.bb_match_priority)) + return [(re.compile(pattern), priority) for pattern, priority in regexes] diff --git a/src/lighteval/tasks/multilingual/adapters.py b/src/lighteval/tasks/multilingual/adapters.py index 59b39fcd6..0ca18479c 100644 --- a/src/lighteval/tasks/multilingual/adapters.py +++ b/src/lighteval/tasks/multilingual/adapters.py @@ -29,6 +29,7 @@ from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.multilingual.utils.adapters_utils import ( extract_answers_from_string, + float_to_choice_string, multichoice_join, multichoice_to_single_choice, ) @@ -79,7 +80,7 @@ def thai_exams_adapter(line: dict) -> MCQInput | None: def alghafa_adapter(line: dict) -> MCQInput | None: answer_index = int(line["label"]) - choices_keys = [key for key in line.keys() if key not in ["query", "label", "__few_shots"]] + choices_keys = [key for key in line.keys() if key not in ["query", "label", "__index", "__few_shots"]] choices = [line[key] for key in choices_keys] return { "question": line["query"], @@ -298,3 +299,55 @@ def enem_adapter(lang: Language, line: dict) -> MCQInput | None: "choices": line["alternatives"], "gold_idx": LETTER_INDICES.index(line["label"]), } + + +CMM_MATH_ANSWER_RE = re.compile(r"([A-D])\.(.*?)(?=[A-D]\.|$)") + + +def cmm_math_adapter(line: dict) -> MCQInput | None: + """Adapter for CMM-Math dataset. + + Processes questions and options, handling cases where: + - Question ends with parentheses that need to be stripped + - Options are space-separated strings starting with A./B./C./D. 
+ """ + # Strip ending parentheses from question + question = line["question"].strip().rstrip("( )") + + # Split options and store as dict with letter keys + choices = {} + for match in CMM_MATH_ANSWER_RE.finditer(line["options"]): + letter, choice = match.groups() + choices[letter] = choice.strip() + + try: + gold_idx = list(choices.keys()).index(line["answer"]) + except ValueError: + gold_idx = None + + # Validate we have enough options and answer + if len(choices) <= 1 or not line.get("answer") or gold_idx is None: + return None + + return {"question": question, "choices": list(choices.values()), "gold_idx": gold_idx} + + +def qazuntv2_adapter(line: dict) -> MCQInput | None: + gold_idx = LETTER_INDICES.index(line["answer"]) + choices = line["options"] + if gold_idx >= len(choices): + return None + return {"question": line["question"], "choices": choices, "gold_idx": gold_idx} + + +MGSM_COT_PREFIX_RE = re.compile( + r"\s*(ধাপে ধাপে উত্তর|Schritt-für-Schritt-Antwort|Step-by-Step Answer|Respuesta paso a paso|Réponse étape par étape|ステップごとの答え|Пошаговое решение|Jibu la Hatua kwa Hatua|దశలవారీగా సమాధానంi|คำตอบทีละขั้นตอน|逐步解答)\s*:\s*" +) +MGSM_QUESTION_RE = re.compile(r"\s*(প্রশ্ন|Frage|Question|Pregunta|Question|問題|Задача|Swali|ప్రశ్న|โจทย์|问题)\s*:\s*") + + +def mgsm_adapter(line: dict) -> QAInput | None: + question = MGSM_QUESTION_RE.sub("", line["question"]) + answer_cot = MGSM_COT_PREFIX_RE.sub("", line["answer"]) if line["answer"] else "" + answer_number = line["answer_number"] + return {"question": question, "few_shot_cot": answer_cot, "choices": [float_to_choice_string(answer_number)]} diff --git a/src/lighteval/tasks/multilingual/tasks.py b/src/lighteval/tasks/multilingual/tasks.py index e94a1e400..70e3466dc 100644 --- a/src/lighteval/tasks/multilingual/tasks.py +++ b/src/lighteval/tasks/multilingual/tasks.py @@ -28,8 +28,10 @@ from lighteval.metrics.dynamic_metrics import ( loglikelihood_acc_metric, + multilingual_extractive_match_metric, multilingual_quasi_exact_match_metric, multilingual_quasi_f1_score_metric, + translation_metric, ) from lighteval.metrics.metrics import Metrics from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm @@ -39,25 +41,36 @@ agieval_adapter, alghafa_adapter, ceval_adapter, - enem_adapter, + cmm_math_adapter, get_m3exam_adapter, get_mkqa_adapter, + mgsm_adapter, + qazuntv2_adapter, sciqa_adapter, thai_exams_adapter, winogrand_adapter, xcodah_adapter, ) -from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation, normalize_subset +from lighteval.tasks.multilingual.utils.adapters_utils import float_to_choice_string +from lighteval.tasks.multilingual.utils.task_utils import ( + get_cot_answer_normalization, + get_cot_generaion_size, + get_metrics_for_mcq_formulation, + get_stop_sequence, + normalize_subset, +) from lighteval.tasks.templates.boolq import get_boolq_prompt_function from lighteval.tasks.templates.continuation import get_continuation_prompt_function from lighteval.tasks.templates.copa import get_copa_prompt_function from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function +from lighteval.tasks.templates.math_qa import get_math_qa_prompt_function from lighteval.tasks.templates.multichoice import get_mcq_prompt_function from lighteval.tasks.templates.nli import get_nli_prompt_function from lighteval.tasks.templates.qa import get_qa_prompt_function from lighteval.tasks.templates.translation import get_translation_prompt_function from 
lighteval.tasks.templates.utils.formulation import ( CFFormulation, + Formulation, HybridFormulation, MCFFormulation, ) @@ -65,7 +78,23 @@ from lighteval.utils.language import Language, iso_639_3_ind_to_iso_639_3_macro, manage_duplicate_language_codes +# Philosophy of formulations: +# 1. For early-stage pretrained models, we recommend using the CF formulation with few-shots (task_cf), which still gives a reasonable signal at this stage +# 2. For later stages, we recommend using the MCF formulation with few-shots (task_mcf_native), as models at this point should be able to handle the MCF formulation +# 3. For post-trained models, we recommend using the MCF formulation without few-shots but with cot (task_mcf_native_cot), as this best matches their real usage and they should be capable of +# following the expected format + +# Similarly, for generative tasks, we recommend the non-cot variants for all pre-trained models and the cot variants for post-trained models + + TASKS_TABLE = [] +DEFAULT_FORMULATIONS: list[Formulation] = [ + MCFFormulation(), + MCFFormulation(cot=True), + MCFFormulation("NativeLetters"), + MCFFormulation("NativeLetters", cot=True), + CFFormulation(), +] # ------------------------------- NLI Tasks ------------------------------- # # NLI (Natural Language Inference) tasks involve determining the logical relationship # between two given sentences: a premise and a hypothesis. The goal is to classify @@ -76,12 +105,41 @@ # The XNLI dataset is a multilingual variant of MultiNLI # https://aclanthology.org/D18-1269/ + + +def get_formulation_name(formulation: Formulation) -> str: + match formulation: + case MCFFormulation("NativeLetters") | HybridFormulation("NativeLetters"): + name = formulation.name.lower() + "_native" + case MCFFormulation("Numbers") | HybridFormulation("Numbers"): + name = formulation.name.lower() + "_numbers" + case _: + name = formulation.name.lower() + if formulation.cot: + name += "_cot" + return name + + +def get_mcf_task_name(base_name: str, language: Language | str, formulation: Formulation | None) -> str: + """Helper function to generate consistent task names.""" + formulation_name = f"_{get_formulation_name(formulation)}" if formulation else "" + + language_name = language.value if isinstance(language, Language) else language + return f"{base_name}_{language_name}{formulation_name}" + + +def get_generative_task_name(base_name: str, language: Language | str, cot: bool) -> str: + language_name = language.value if isinstance(language, Language) else language + return f"{base_name}_{language_name}{'_cot' if cot else ''}" + + xnli_tasks = [ LightevalTaskConfig( - name=f"xnli_{language.value}_{formulation.name.lower()}", + name=get_mcf_task_name("xnli", language, formulation), suite=["lighteval"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=None), loglikelihood_acc_metric(normalization=LogProbTokenNorm()), @@ -104,6 +162,7 @@ hf_subset=standardize_tag(language.value), evaluation_splits=["validation"], few_shots_split="train", + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.ARABIC, Language.BULGARIAN, Language.GERMAN, Language.GREEK, Language.ENGLISH, Language.SPANISH, Language.FRENCH, Language.HINDI, Language.RUSSIAN, Language.SWAHILI, Language.THAI, Language.TURKISH, Language.URDU, Language.VIETNAMESE, Language.CHINESE, ] - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] - # Improvement on XNLI with better translation, from our experience models tend to # perform better on XNLI2.0 than XNLI # https://arxiv.org/abs/2301.06527 xnli2_tasks
= [ LightevalTaskConfig( - name=f"xnli2.0_{language.value}_{formulation.name.lower()}", + name=get_mcf_task_name("xnli2.0", language, formulation), suite=["lighteval"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=None), loglikelihood_acc_metric(normalization=LogProbTokenNorm()), @@ -161,6 +220,7 @@ hf_subset="default", evaluation_splits=["train"], hf_avail_splits=["train"], + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.ENGLISH, @@ -189,14 +249,14 @@ Language.ARABIC, # Theoretically also: Bhojpuri, Gujarati, Odiya ] - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] # Another variant of XNLI, with emphasis on Indic languages # https://arxiv.org/abs/2204.08776 xnli_indic_tasks = [ LightevalTaskConfig( - name=f"indicnxnli_{language.value}_{formulation.name.lower()}", + name=get_mcf_task_name("indicxnli", language, formulation), suite=["lighteval"], prompt_function=get_nli_prompt_function( language=language, @@ -215,14 +275,16 @@ hf_filter=lambda x: int(x["label"]) in [0, 2], evaluation_splits=["validation"], few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=None), loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.ASSAMESE, @@ -237,14 +299,14 @@ Language.TAMIL, Language.TELUGU, ] - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] # African XNLI: African XNLI # From https://arxiv.org/abs/2406.03368. Human translated MMLU. 
afri_xnli_tasks = [ LightevalTaskConfig( - name=f"afri_xnli_{language.value}_{formulation.name.lower()}", + name=get_mcf_task_name("afri_xnli", language, formulation), suite=("lighteval",), prompt_function=get_nli_prompt_function( language=language, @@ -262,14 +324,16 @@ hf_filter=lambda x: int(x["label"]) in [0, 2], evaluation_splits=("test",), few_shots_split="validation", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=None), loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.AMHARIC, @@ -290,7 +354,7 @@ Language.YORUBA, # Language.ZULU, ] - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] # PAWS-X: A Cross-lingual Adversarial Dataset for Paraphrase Identification @@ -301,7 +365,7 @@ paws_x_tasks = [ LightevalTaskConfig( - name=f"pawsx_{language.value}_{formulation.name.lower()}", + name=get_mcf_task_name("pawsx", language, formulation), suite=("lighteval",), prompt_function=get_nli_prompt_function( language=language, @@ -318,14 +382,16 @@ hf_subset=standardize_tag(language.value), evaluation_splits=("test",), few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=None), loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.GERMAN, @@ -336,7 +402,7 @@ Language.KOREAN, Language.CHINESE, ] - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] # Russian Commitment Bank (RCB) is a large-scale NLI dataset with Russian sentences, @@ -344,13 +410,12 @@ # https://arxiv.org/abs/2401.04531 rcb_tasks = [ LightevalTaskConfig( - name=f"rcb_{Language.RUSSIAN.value}_{formulation.name.lower()}", + name=get_mcf_task_name("rcb", Language.RUSSIAN, formulation), prompt_function=get_nli_prompt_function( language=Language.RUSSIAN, adapter=lambda line: { "premise": line["inputs"]["premise"], "hypothesis": line["inputs"]["hypothesis"], - # Since we ignore the neutral label "gold_idx": int(line["outputs"]) - 1, }, relations=["entailment", "contradiction"], @@ -359,20 +424,21 @@ suite=("lighteval",), hf_repo="ai-forever/MERA", hf_subset="rcb", - # Ignore neutral label hf_filter=lambda x: int(x["outputs"] or "0") in [1, 2], evaluation_splits=("train",), few_shots_split="validation", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.RUSSIAN, [ loglikelihood_acc_metric(normalization=None), loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot), ) - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] # Native Chinese NLI dataset based. 
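# A quick sketch of the naming scheme introduced above (the concrete suffixes assume
# formulation.name.lower() yields "mcf" / "cf", as in the pre-existing task names): each base
# task is expanded once per entry of DEFAULT_FORMULATIONS, so for a language whose
# Language.X.value is "<lang>", xnli yields
#   xnli_<lang>_mcf, xnli_<lang>_mcf_cot, xnli_<lang>_mcf_native,
#   xnli_<lang>_mcf_native_cot and xnli_<lang>_cf,
# while generative tasks only gain an optional "_cot" suffix, e.g.
#   get_generative_task_name("xquad", language, cot=True) -> "xquad_<lang>_cot".
# Equivalent check, kept commented out to avoid import-time side effects:
#     for formulation in DEFAULT_FORMULATIONS:
#         print(get_mcf_task_name("xnli", Language.FRENCH, formulation))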
@@ -380,13 +446,12 @@ # We find this benchmark to have really good signal compared to other Chinese NLI ocnli_tasks = [ LightevalTaskConfig( - name=f"ocnli_{Language.CHINESE.value}_{formulation.name.lower()}", + name=get_mcf_task_name("ocnli", Language.CHINESE, formulation), prompt_function=get_nli_prompt_function( language=Language.CHINESE, adapter=lambda line: { "premise": line["sentence1"], "hypothesis": line["sentence2"], - # Since we ignore the neutral label "gold_idx": {1: 0, 2: 1}[line["label"]], }, relations=["entailment", "contradiction"], @@ -395,33 +460,33 @@ suite=("lighteval",), hf_repo="clue/clue", hf_subset="ocnli", - # Only keep the positive and negative examples hf_filter=lambda x: int(x["label"]) in [1, 2], evaluation_splits=("validation",), few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.CHINESE, [ loglikelihood_acc_metric(normalization=None), loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot), ) - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] # https://arxiv.org/abs/2004.05986 # Native Chinese NLI dataset based on MNLI approach (Machine Translated) cmnli_tasks = [ LightevalTaskConfig( - name=f"cmnli_{Language.CHINESE.value}_{formulation.name.lower()}", + name=get_mcf_task_name("cmnli", Language.CHINESE, formulation), prompt_function=get_nli_prompt_function( language=Language.CHINESE, adapter=lambda line: { "premise": line["sentence1"], "hypothesis": line["sentence2"], - # Since we ignore the neutral label "gold_idx": {"entailment": 0, "contradiction": 1}[line["label"]], }, relations=["entailment", "contradiction"], @@ -431,19 +496,20 @@ hf_repo="fenffef/cmnli", hf_subset="default", hf_filter=lambda x: x["label"] in ["entailment", "contradiction"], - # Only keep the positive and negative examples evaluation_splits=("validation",), few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.CHINESE, [ loglikelihood_acc_metric(normalization=None), loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot), ) - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -467,7 +533,7 @@ # XCOPA extends the original English COPA task to 11 typologically diverse languages. 
xcopa_tasks = [ LightevalTaskConfig( - name=f"xcopa_{language.value}_{formulation.name.lower()}", + name=get_mcf_task_name("xcopa", language, formulation), suite=["lighteval"], prompt_function=get_copa_prompt_function( language, @@ -483,13 +549,15 @@ hf_subset=("copa_ext_ar" if language == Language.ARABIC else standardize_tag(language.value)), evaluation_splits=["test"], few_shots_split="validation", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.ARABIC, @@ -505,7 +573,7 @@ Language.HAITIAN, Language.QUECHUA, ] - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] # IndicCOPA: COPA for Indic Languages @@ -514,7 +582,7 @@ # evaluating common sense reasoning in these languages. copa_indic_tasks = [ LightevalTaskConfig( - name=f"indicxcopa_{language.value}_{formulation.name.lower()}", + name=get_mcf_task_name("indicxcopa", language, formulation), suite=["lighteval"], prompt_function=get_copa_prompt_function( language, @@ -533,14 +601,16 @@ hf_revision="d356ef19a4eb287e88a51d07a56b73ba88c7f188", evaluation_splits=["test"], hf_avail_splits=["test"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), trust_dataset=True, + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.ASSAMESE, @@ -560,7 +630,7 @@ Language.URDU, # Optionally: Maithili, Santali, Sindhi, Konkani ] - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] # PARus: Plausible Alternatives for Russian @@ -569,7 +639,7 @@ # It evaluates common sense reasoning and causal inference abilities in Russian language models. parus_tasks = [ LightevalTaskConfig( - name=f"parus_{Language.RUSSIAN.value}_{formulation.name.lower()}", + name=get_mcf_task_name("parus", Language.RUSSIAN, formulation), suite=["lighteval"], prompt_function=get_copa_prompt_function( language=Language.RUSSIAN, @@ -585,17 +655,21 @@ hf_subset="parus", evaluation_splits=["train"], few_shots_split="validation", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.RUSSIAN, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot), ) - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] +# Rerun hellaswag tasks +# Add xcopa to thai TASKS_TABLE.extend([*xcopa_tasks, *copa_indic_tasks, *parus_tasks]) # ------------------------------- Hellaswag Tasks ------------------------------- # @@ -609,12 +683,12 @@ # It evaluates commonsense reasoning abilities across multiple languages. 
mlmm_hellaswag_tasks = [ LightevalTaskConfig( - name=f"mlmm_hellaswag_{lang.value}_{formulation.name.lower()}", + name=get_mcf_task_name("mlmm_hellaswag", lang, formulation), suite=["lighteval"], prompt_function=get_hellaswag_prompt_function( language=lang, adapter=lambda line: { - # We don't use activity_label as they are not available + # activity_label is only available in english, thus we don't use it "ctx_a": line["ctx_a"], "ctx_b": line["ctx_b"], "continuations": line["endings"], @@ -629,14 +703,16 @@ hf_revision="96ed8e0dfc6172dad1d3df338d7b8ba6c1ff9d83", evaluation_splits=["validation"], hf_avail_splits=["validation"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + lang, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), trust_dataset=True, + stop_sequence=get_stop_sequence(lang, formulation.cot), ) for lang in [ Language.ARABIC, @@ -673,7 +749,7 @@ Language.VIETNAMESE, Language.CHINESE, ] - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] # Hellaswag Turkish @@ -685,13 +761,12 @@ # which would make it hard to read hellaswag_tur_tasks = [ LightevalTaskConfig( - name=f"community_hellaswag_{Language.TURKISH.value}_{formulation.name.lower()}", + name=get_mcf_task_name("community_hellaswag", Language.TURKISH, formulation), suite=["lighteval"], prompt_function=get_hellaswag_prompt_function( language=Language.TURKISH, adapter=lambda line: { - "ctx_a": line["ctx_a"], - "ctx_b": line["ctx_b"], + "ctx_a": line["ctx"], "continuations": line["endings"], "gold_idx": int(line["label"]), }, @@ -703,15 +778,17 @@ hf_subset="default", evaluation_splits=["validation"], hf_avail_splits=["validation"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.TURKISH, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot), ) - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] # Hellaswag Thai @@ -720,13 +797,13 @@ # for evaluating Thai language models on commonsense reasoning tasks. 
hellaswag_tha_tasks = [ LightevalTaskConfig( - name=f"community_hellaswag_{Language.THAI.value}_{formulation.name.lower()}", + name=get_mcf_task_name("community_hellaswag", Language.THAI, formulation), suite=["lighteval"], prompt_function=get_hellaswag_prompt_function( language=Language.THAI, adapter=lambda line: { - "ctx_a": line["ctx_a"], - "ctx_b": line["ctx_b"], + "activity_label": line["activity_label"], + "ctx_a": line["ctx"], "continuations": line["endings"], "gold_idx": int(line["label"]), }, @@ -737,72 +814,49 @@ hf_subset="default", evaluation_splits=["validation"], few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.THAI, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.THAI, formulation.cot), ) - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] -] - -hellaswag_hin_tasks = [ - LightevalTaskConfig( - name=f"community_hellaswag_{Language.HINDI.value}_{formulation.name.lower()}", - suite=["lighteval"], - prompt_function=get_hellaswag_prompt_function( - language=Language.HINDI, - adapter=lambda line: { - "ctx_a": line["ctx_a"], - "continuations": line["endings"], - "gold_idx": int(line["label"]), - }, - formulation=formulation, - ), - hf_repo="ai4bharat/hellaswag-hi", - hf_filter=lambda line: all(len(choice.strip()) > 0 for choice in line["endings"]), - hf_subset="hi", - evaluation_splits=("validation",), - few_shots_split="validation", - metric=get_metrics_for_formulation( - formulation, - [ - loglikelihood_acc_metric(normalization=LogProbCharNorm()), - loglikelihood_acc_metric(normalization=LogProbTokenNorm()), - ], - ), - ) - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] hellaswag_tel_tasks = [ LightevalTaskConfig( - name=f"community_hellaswag_{Language.TELUGU.value}_{formulation.name.lower()}", + name=get_mcf_task_name("community_hellaswag", Language.TELUGU, formulation), suite=["lighteval"], prompt_function=get_hellaswag_prompt_function( language=Language.TELUGU, adapter=lambda line: { + # Activity label is only available in english, thus we don't use it "ctx_a": line["ctx_a"], - "continuations": line["endings"], + "continuations": [line["a"], line["b"], line["c"], line["d"]], "gold_idx": int(line["label"]), }, formulation=formulation, + wikihow_artifacts=[" [శీర్షిక]", " [హెడర్]", " [header]", " [Header]"], ), - hf_repo="LightFury9/hellaswag-telugu", + hf_repo="indiehackers/hellaswag-telugu-custom-2k", hf_subset="default", - evaluation_splits=("valid",), - few_shots_split="train", - metric=get_metrics_for_formulation( + evaluation_splits=("train",), + hf_avail_splits=("train",), + metric=get_metrics_for_mcq_formulation( formulation, + Language.TELUGU, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.TELUGU, formulation.cot), ) - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -810,7 +864,6 @@ *mlmm_hellaswag_tasks, *hellaswag_tur_tasks, *hellaswag_tha_tasks, - *hellaswag_hin_tasks, *hellaswag_tel_tasks, ] ) @@ -825,7 +878,7 @@ # https://arxiv.org/abs/1910.11856 xquad_tasks = [ LightevalTaskConfig( - name=f"xquad_{language.value}", + name=get_generative_task_name("xquad", 
language, cot), prompt_function=get_qa_prompt_function( language, lambda line: { @@ -833,18 +886,21 @@ "context": line["context"], "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], }, + cot=cot, ), suite=("lighteval",), hf_repo="google/xquad", hf_subset=f"xquad.{standardize_tag(language.value)}", evaluation_splits=("validation",), few_shots_split="validation", - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), metric=( - multilingual_quasi_exact_match_metric(language, "prefix"), - multilingual_quasi_f1_score_metric(language), + multilingual_quasi_exact_match_metric( + language, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(language, normalize_pred=get_cot_answer_normalization(cot)), ), + stop_sequence=get_stop_sequence(language, cot), ) for language in [ Language.ARABIC, @@ -860,72 +916,13 @@ Language.VIETNAMESE, Language.CHINESE, ] + for cot in [False, True] ] -# GermanQuAD: High-quality German QA dataset with 13,722 questions -# https://arxiv.org/abs/2104.12741 -germanquad_tasks = [ - LightevalTaskConfig( - name=f"germanquad_{Language.GERMAN.value}", - prompt_function=get_qa_prompt_function( - Language.GERMAN, - lambda line: { - "question": line["question"], - "context": line["context"], - "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], - }, - ), - suite=("lighteval",), - hf_repo="deepset/germanquad", - hf_subset="plain_text", - trust_dataset=True, - hf_revision="fff05ceaf2ffbe5b65c7e0c57e678f7b7e1a0581", - hf_filter=lambda line: any(len(ans) > 0 for ans in line["answers"]["text"]), - evaluation_splits=("test",), - few_shots_split="train", - generation_size=400, - stop_sequence=("\n",), - metric=( - multilingual_quasi_exact_match_metric(Language.GERMAN, "prefix"), - multilingual_quasi_f1_score_metric(Language.GERMAN), - ), - ) -] - - -# SQuAD-it: Italian translation of the SQuAD dataset -# https://github.com/crux82/squad-it -squad_it_tasks = [ - LightevalTaskConfig( - name=f"squad_{Language.ITALIAN.value}", - prompt_function=get_qa_prompt_function( - Language.ITALIAN, - lambda line: { - "question": line["question"], - "context": line["context"], - "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], - }, - ), - suite=("lighteval",), - hf_repo="crux82/squad_it", - hf_subset="default", - hf_filter=lambda line: any(len(ans) > 0 for ans in line["answers"]["text"]), - evaluation_splits=("test",), - few_shots_split="train", - generation_size=400, - stop_sequence=("\n",), - metric=( - multilingual_quasi_exact_match_metric(Language.ITALIAN, "prefix"), - multilingual_quasi_f1_score_metric(Language.ITALIAN), - ), - ) -] - - # ThaiQA: A question answering dataset for the Thai language. 
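# Note on the extractive QA tasks in this section: each dataset is emitted twice, once per value
# of `cot`, e.g. "xquad_<lang>" and "xquad_<lang>_cot". For the cot variant, the generation
# budget comes from get_cot_generaion_size(cot, 400) and predictions go through the normalizer
# returned by get_cot_answer_normalization(cot) before the exact-match / F1 scoring. Both
# helpers are imported from task_utils and are not shown in this diff, so their exact behavior
# (larger budget, answer extraction) is assumed here rather than taken from the source.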
thaiqa_tasks = [ LightevalTaskConfig( - name=f"thaiqa_{Language.THAI.value}", + name=get_generative_task_name("thaiqa", Language.THAI, cot), prompt_function=get_qa_prompt_function( Language.THAI, lambda line: { @@ -933,26 +930,30 @@ "context": line["context"], "choices": [ans for ans in line["answers"]["answer"] if len(ans) > 0], }, + cot=cot, ), suite=("lighteval",), hf_repo="lighteval/thaiqa_squad_fixed", hf_subset="default", evaluation_splits=("train",), few_shots_split="validation", - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), + stop_sequence=get_stop_sequence(Language.THAI, cot), metric=( - multilingual_quasi_exact_match_metric(Language.THAI, "prefix"), - multilingual_quasi_f1_score_metric(Language.THAI), + multilingual_quasi_exact_match_metric( + Language.THAI, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(Language.THAI, normalize_pred=get_cot_answer_normalization(cot)), ), ) + for cot in [False, True] ] # SberQuAD: A large-scale Russian reading comprehension dataset. # https://arxiv.org/abs/1912.09723 sber_squad_tasks = [ LightevalTaskConfig( - name=f"sber_squad_{Language.RUSSIAN.value}", + name=get_generative_task_name("sber_squad", Language.RUSSIAN, cot), prompt_function=get_qa_prompt_function( Language.RUSSIAN, lambda line: { @@ -960,6 +961,7 @@ "context": line["context"], "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], }, + cot=cot, ), suite=("lighteval",), hf_repo="kuznetsoffandrey/sberquad", @@ -967,79 +969,22 @@ evaluation_splits=("validation",), few_shots_split="train", metric=( - multilingual_quasi_exact_match_metric(Language.RUSSIAN, "prefix"), - multilingual_quasi_f1_score_metric(Language.RUSSIAN), - ), - generation_size=400, - stop_sequence=("\n",), - ) -] - -# FaQuAD: A Portuguese Reading Comprehension Dataset -# https://arxiv.org/abs/2007.15671 -faquad_tasks = [ - LightevalTaskConfig( - name=f"faquad_{Language.PORTUGUESE.value}", - prompt_function=get_qa_prompt_function( - Language.PORTUGUESE, - lambda line: { - "question": line["question"], - "context": line["context"], - "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], - }, - ), - suite=("lighteval",), - hf_repo="eraldoluis/faquad", - hf_subset="plain_text", - trust_dataset=True, - hf_revision="205ba826a2282a4a5aa9bd3651e55ee4f2da1546", - hf_filter=lambda line: any(len(ans) > 0 for ans in line["answers"]["text"]), - evaluation_splits=("validation",), - few_shots_split="train", - metric=( - multilingual_quasi_exact_match_metric(Language.PORTUGUESE, "prefix"), - multilingual_quasi_f1_score_metric(Language.PORTUGUESE), - ), - generation_size=400, - stop_sequence=("\n",), - ) -] - - -# SQuAD-es: Spanish translation of the Stanford Question Answering Dataset -# https://huggingface.co/datasets/ccasimiro/squad_es -squad_es_tasks = [ - LightevalTaskConfig( - name=f"squad_{Language.SPANISH.value}", - prompt_function=get_qa_prompt_function( - Language.SPANISH, - lambda line: { - "question": line["question"], - "context": line["context"], - "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], - }, - ), - suite=("lighteval",), - hf_repo="ccasimiro/squad_es", - hf_subset="v2.0.0", - hf_filter=lambda line: any(len(ans) > 0 for ans in line["answers"]["text"]), - evaluation_splits=("validation",), - few_shots_split="train", - metric=( - multilingual_quasi_exact_match_metric(Language.SPANISH, "prefix"), - multilingual_quasi_f1_score_metric(Language.SPANISH), + 
multilingual_quasi_exact_match_metric( + Language.RUSSIAN, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(Language.RUSSIAN, normalize_pred=get_cot_answer_normalization(cot)), ), - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), + stop_sequence=get_stop_sequence(Language.RUSSIAN, cot), ) + for cot in [False, True] ] - # ARCD: Arabic Reading Comprehension Dataset. # https://arxiv.org/pdf/1906.05394 arcd_tasks = [ LightevalTaskConfig( - name=f"arcd_{Language.ARABIC.value}", + name=get_generative_task_name("arcd", Language.ARABIC, cot), prompt_function=get_qa_prompt_function( Language.ARABIC, lambda line: { @@ -1047,6 +992,7 @@ "context": line["context"], "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], }, + cot=cot, ), suite=("lighteval",), hf_repo="hsseinmz/arcd", @@ -1057,16 +1003,17 @@ multilingual_quasi_exact_match_metric(Language.ARABIC, "prefix"), multilingual_quasi_f1_score_metric(Language.ARABIC), ), - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), + stop_sequence=get_stop_sequence(Language.ARABIC, cot), ) + for cot in [False, True] ] # KenSwQuAD: A question answering dataset for Kenyan Swahili. # https://arxiv.org/abs/2205.02364 kenswquad_tasks = [ LightevalTaskConfig( - name=f"kenswquad_{Language.SWAHILI.value}", + name=get_generative_task_name("kenswquad", Language.SWAHILI, cot), prompt_function=get_qa_prompt_function( Language.SWAHILI, lambda line: { @@ -1074,6 +1021,7 @@ "context": line["context"], "choices": [line["answer"]], }, + cot=cot, ), suite=("lighteval",), hf_repo="lighteval/KenSwQuAD", @@ -1081,19 +1029,22 @@ evaluation_splits=("test",), few_shots_split="validation", metric=( - multilingual_quasi_exact_match_metric(Language.SWAHILI, "prefix"), - multilingual_quasi_f1_score_metric(Language.SWAHILI), + multilingual_quasi_exact_match_metric( + Language.SWAHILI, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(Language.SWAHILI, normalize_pred=get_cot_answer_normalization(cot)), ), - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), + stop_sequence=get_stop_sequence(Language.SWAHILI, cot), ) + for cot in [False, True] ] # ChineseSquad: A reading comprehension dataset for Chinese. 
# https://github.com/pluto-junzeng/ChineseSquad chinese_squad_tasks = [ LightevalTaskConfig( - name=f"chinese_squad_{Language.CHINESE.value}", + name=get_generative_task_name("chinese_squad", Language.CHINESE, cot), prompt_function=get_qa_prompt_function( Language.CHINESE, lambda line: { @@ -1101,6 +1052,7 @@ "context": line["context"], "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], }, + cot=cot, ), suite=("lighteval",), hf_repo="lighteval/ChineseSquad", @@ -1108,19 +1060,22 @@ evaluation_splits=("validation",), few_shots_split="train", metric=( - multilingual_quasi_exact_match_metric(Language.CHINESE, "prefix"), - multilingual_quasi_f1_score_metric(Language.CHINESE), + multilingual_quasi_exact_match_metric( + Language.CHINESE, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(Language.CHINESE, normalize_pred=get_cot_answer_normalization(cot)), ), - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), + stop_sequence=get_stop_sequence(Language.CHINESE, cot), ) + for cot in [False, True] ] # CMRC 2018: A span-extraction machine reading comprehension dataset for Chinese. # https://arxiv.org/abs/1810.07366 cmrc2018_tasks = [ LightevalTaskConfig( - name=f"cmrc2018_{Language.CHINESE.value}", + name=get_generative_task_name("cmrc2018", Language.CHINESE, cot), prompt_function=get_qa_prompt_function( Language.CHINESE, lambda line: { @@ -1128,26 +1083,30 @@ "context": line["context"], "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], }, + cot=cot, ), suite=("lighteval",), hf_repo="clue/clue", hf_subset="cmrc2018", evaluation_splits=("trial",), few_shots_split="train", - generation_size=400, + generation_size=get_cot_generaion_size(cot, 400), metric=( - multilingual_quasi_exact_match_metric(Language.CHINESE, "prefix"), - multilingual_quasi_f1_score_metric(Language.CHINESE), + multilingual_quasi_exact_match_metric( + Language.CHINESE, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(Language.CHINESE, normalize_pred=get_cot_answer_normalization(cot)), ), - stop_sequence=("\n",), + stop_sequence=get_stop_sequence(Language.CHINESE, cot), ) + for cot in [False, True] ] # IndicQA: A reading comprehension dataset for 11 Indian languages. # https://arxiv.org/abs/2407.13522 indicqa_tasks = [ LightevalTaskConfig( - name=f"indicqa_{language.value}", + name=get_generative_task_name("indicqa", language, cot), prompt_function=get_qa_prompt_function( language, lambda line: { @@ -1155,6 +1114,7 @@ "context": line["context"], "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], }, + cot=cot, ), suite=("lighteval",), hf_repo="ai4bharat/IndicQA", @@ -1166,12 +1126,14 @@ trust_dataset=True, evaluation_splits=("test",), hf_avail_splits=("test",), - generation_size=400, + generation_size=get_cot_generaion_size(cot, 400), metric=( - multilingual_quasi_exact_match_metric(language, "prefix"), - multilingual_quasi_f1_score_metric(language), + multilingual_quasi_exact_match_metric( + language, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(language, normalize_pred=get_cot_answer_normalization(cot)), ), - stop_sequence=("\n",), + stop_sequence=get_stop_sequence(language, cot), ) for language in [ Language.ASSAMESE, @@ -1186,13 +1148,14 @@ Language.TAMIL, Language.TELUGU, ] + for cot in [False, True] ] # FQuAD v2: French Question Answering Dataset version 2. 
# https://arxiv.org/abs/2002.06071 fquad_v2_tasks = [ LightevalTaskConfig( - name=f"fquadv2_{Language.FRENCH.value}", + name=get_generative_task_name("fquadv2", Language.FRENCH, cot), prompt_function=get_qa_prompt_function( Language.FRENCH, lambda line: { @@ -1200,25 +1163,29 @@ "context": line["context"], "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], }, + cot=cot, ), suite=("lighteval",), hf_repo="manu/fquad2_test", hf_subset="default", evaluation_splits=("test_hasAns",), few_shots_split="valid_hasAns", - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), metric=( - multilingual_quasi_exact_match_metric(Language.FRENCH, "prefix"), - multilingual_quasi_f1_score_metric(Language.FRENCH), + multilingual_quasi_exact_match_metric( + Language.FRENCH, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(Language.FRENCH, normalize_pred=get_cot_answer_normalization(cot)), ), + stop_sequence=get_stop_sequence(Language.FRENCH, cot), ) + for cot in [False, True] ] # TQuAD v2: Turkish Question Answering Dataset version 2. tquad_v2_tasks = [ LightevalTaskConfig( - name=f"tquadv2_{Language.TURKISH.value}", + name=get_generative_task_name("tquadv2", Language.TURKISH, cot), prompt_function=get_qa_prompt_function( Language.TURKISH, lambda line: { @@ -1226,19 +1193,23 @@ "context": line["context"], "choices": [a["text"] for a in line["answers"]], }, + cot=cot, ), suite=("lighteval",), hf_repo="erdometo/tquad2", hf_subset="default", evaluation_splits=("validation",), few_shots_split="train", - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), metric=( - multilingual_quasi_exact_match_metric(Language.TURKISH, "prefix"), - multilingual_quasi_f1_score_metric(Language.TURKISH), + multilingual_quasi_exact_match_metric( + Language.TURKISH, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(Language.TURKISH, normalize_pred=get_cot_answer_normalization(cot)), ), + stop_sequence=get_stop_sequence(Language.TURKISH, cot), ) + for cot in [False, True] ] # Other QA tasks for RC @@ -1247,7 +1218,7 @@ # https://arxiv.org/abs/2003.05002 tydiqa_tasks = [ LightevalTaskConfig( - name=f"tydiqa_{language.value}", + name=get_generative_task_name("tydiqa", language, cot), prompt_function=get_qa_prompt_function( language, lambda line: { @@ -1255,18 +1226,21 @@ "context": line["context"], "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], }, + cot=cot, ), suite=("lighteval",), hf_repo="google-research-datasets/tydiqa", hf_subset="secondary_task", evaluation_splits=("validation",), few_shots_split="train", - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), metric=( - multilingual_quasi_exact_match_metric(language, "prefix"), - multilingual_quasi_f1_score_metric(language), + multilingual_quasi_exact_match_metric( + language, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(language, normalize_pred=get_cot_answer_normalization(cot)), ), + stop_sequence=get_stop_sequence(language, cot), ) for language in [ Language.ENGLISH, @@ -1281,6 +1255,7 @@ Language.TELUGU, Language.THAI, ] + for cot in [False, True] ] # C3: A Chinese Challenge Corpus for Cross-lingual and Cross-modal Tasks @@ -1288,7 +1263,7 @@ # Paper: https://arxiv.org/abs/2004.05986 c3_tasks = [ LightevalTaskConfig( - 
name=f"c3_{Language.CHINESE.value}_{formulation.name.lower()}", + name=get_mcf_task_name("c3", Language.CHINESE, formulation), suite=("lighteval",), prompt_function=get_mcq_prompt_function( Language.CHINESE, @@ -1304,19 +1279,17 @@ hf_subset="c3", evaluation_splits=("validation",), few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.CHINESE, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # Other MCF tasks for RC @@ -1326,7 +1299,7 @@ # Paper: https://aclanthology.org/2023.arabicnlp-1.21/ race_ar_task = [ LightevalTaskConfig( - name=f"alghafa_race_{Language.ARABIC.value}_{formulation.name.lower()}", + name=get_mcf_task_name("alghafa_race", Language.ARABIC, formulation), prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation), suite=["lighteval"], hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Translated", @@ -1338,44 +1311,40 @@ evaluation_splits=["test"], few_shots_split="validation", trust_dataset=True, - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.ARABIC, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # SOQAL: A large-scale Arabic reading comprehension dataset. # https://arxiv.org/abs/1906.05394 soqal_tasks = [ LightevalTaskConfig( - name=f"soqal_{Language.ARABIC.value}_{formulation.name.lower()}", + name=get_mcf_task_name("soqal", Language.ARABIC, formulation), hf_subset="multiple_choice_grounded_statement_soqal_task", prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation), evaluation_splits=["test"], few_shots_split="validation", suite=["lighteval"], hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Native", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.ARABIC, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # MLQA (MultiLingual Question Answering) is a benchmark dataset for evaluating cross-lingual question answering performance. 
@@ -1384,7 +1353,7 @@ # Paper: https://arxiv.org/abs/1910.07475 mlqa_tasks = [ LightevalTaskConfig( - name=f"mlqa_{lang.value}", + name=get_generative_task_name("mlqa", lang, cot), prompt_function=get_qa_prompt_function( lang, lambda line: { @@ -1392,6 +1361,7 @@ "question": line["question"], "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0], }, + cot=cot, ), suite=("lighteval",), hf_repo="facebook/mlqa", @@ -1400,12 +1370,12 @@ trust_dataset=True, evaluation_splits=("test",), hf_avail_splits=["test"], - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), metric=[ - multilingual_quasi_exact_match_metric(lang, "prefix"), - multilingual_quasi_f1_score_metric(lang), + multilingual_quasi_exact_match_metric(lang, "prefix", normalize_pred=get_cot_answer_normalization(cot)), + multilingual_quasi_f1_score_metric(lang, normalize_pred=get_cot_answer_normalization(cot)), ], + stop_sequence=get_stop_sequence(lang, cot), ) for lang in [ Language.ARABIC, @@ -1415,13 +1385,18 @@ Language.HINDI, Language.VIETNAMESE, ] + for cot in [False, True] ] # Belebele: A large-scale reading comprehension dataset covering 122 languages. # https://arxiv.org/abs/2308.16884 belebele_tasks = [ LightevalTaskConfig( - name=f"belebele_{language}_{formulation.name.lower()}", + name=get_mcf_task_name( + "belebele", + language, + formulation, + ), prompt_function=get_mcq_prompt_function( iso_639_3_ind_to_iso_639_3_macro[LangCodeLanguage.get(language).to_alpha3()], lambda line: { @@ -1437,15 +1412,20 @@ hf_subset=language, evaluation_splits=("test",), hf_avail_splits=["test"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.ARABIC, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence( + iso_639_3_ind_to_iso_639_3_macro[LangCodeLanguage.get(language).to_alpha3()], + formulation.cot, + ), ) - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for formulation in DEFAULT_FORMULATIONS for language in [ "acm_Arab", "arz_Arab", @@ -1589,10 +1569,6 @@ *race_ar_task, *belebele_tasks, *c3_tasks, - *squad_it_tasks, - *squad_es_tasks, - *faquad_tasks, - *germanquad_tasks, ] ) @@ -1670,7 +1646,7 @@ # Paper: https://arxiv.org/abs/2407.21783 meta_mmlu_tasks = [ LightevalTaskConfig( - name=f"meta_mmlu_{language.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('meta_mmlu', language, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( language, lambda line: { @@ -1691,14 +1667,16 @@ ), evaluation_splits=("latest",), hf_avail_splits=["latest"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for subset in MMLU_SUBSETS for language in [ @@ -1710,18 +1688,14 @@ Language.PORTUGUESE, Language.THAI, ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # MLMM MMLU: Another multilingual version of MMLU # Paper: https://github.com/nlp-uoregon/mlmm-evaluation mlmm_mmlu_tasks = [ LightevalTaskConfig( - name=f"mlmm_mmlu_{language.value}_{formulation.name.lower()}:{subset}", + 
name=f"{get_mcf_task_name('mlmm_mmlu', language, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( language, lambda line: { @@ -1739,14 +1713,16 @@ trust_dataset=True, evaluation_splits=("test",), few_shots_split="dev", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for subset in MMLU_SUBSETS for language in [ @@ -1777,16 +1753,12 @@ Language.TELUGU, Language.KANNADA, ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] openai_mmlu_tasks = [ LightevalTaskConfig( - name=f"openai_mmlu_{language[0].value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('openai_mmlu', language[0], formulation)}:{subset}", prompt_function=get_mcq_prompt_function( language[0], lambda line: { @@ -1803,14 +1775,16 @@ hf_avail_splits=["test"], hf_filter=partial(lambda subset, x: x["Subject"].lower() == subset, subset), hf_revision="038c7808122969ead7456361af05cb8f47d247f8", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language[0], [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(language[0], formulation.cot), ) for subset in MMLU_SUBSETS for language in [ @@ -1829,11 +1803,7 @@ (Language.YORUBA, "YO_NG"), (Language.CHINESE, "ZH_CN"), ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # Translated MMLU using both professional and non-professional translators. Contains tags for cultural sensitivity. @@ -1863,13 +1833,13 @@ lambda subset, sensitivity_label, x: x["subject"].lower() == subset and ( sensitivity_label == "ALL" or sensitivity_label in x["cultural_sensitivity_label"].replace("-", "UNK") - ) - and all(x[f"option_{opt}"] is not None and x[f"option_{opt}"].strip() for opt in "abcd"), + ), subset, sensitivity_label, ), - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), @@ -1914,11 +1884,7 @@ Language.YORUBA, Language.ZULU, ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS for sensitivity_label in ["ALL", "CA", "CS", "UNK"] ] @@ -1936,7 +1902,7 @@ # From https://arxiv.org/abs/2406.03368. Human translated MMLU. 
afri_mmlu_tasks = [ LightevalTaskConfig( - name=f"afri_mmlu_{language.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('afri_mmlu', language, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( language, lambda line: { @@ -1954,14 +1920,16 @@ hf_filter=partial(lambda subset, line: line["subject"] == subset, subset), evaluation_splits=("test",), few_shots_split="dev", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for subset in AFRI_MMLU_SUBSETS for language in [ @@ -1983,18 +1951,14 @@ Language.YORUBA, # Language.ZULU, ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # RUMMLU: Russian Massive Multitask Language Understanding # Paper: https://arxiv.org/html/2401.04531v2 rummlu = [ LightevalTaskConfig( - name=f"rummlu_{Language.RUSSIAN.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('rummlu', Language.RUSSIAN, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( Language.RUSSIAN, lambda line: { @@ -2010,28 +1974,26 @@ hf_filter=lambda x: x["meta"]["domain"] == subset, evaluation_splits=("public_test",), hf_avail_splits=["public_test"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.RUSSIAN, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot), ) for subset in MMLU_SUBSETS - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # MMLU Turkish: Turkish version of MMLU # Translated using openai GPT mmlu_turkish = [ LightevalTaskConfig( - name=f"community_mmlu_{Language.TURKISH.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('community_mmlu', Language.TURKISH, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( Language.TURKISH, lambda line: {"question": line["question"], "choices": line["choices"], "gold_idx": int(line["answer"])}, @@ -2042,21 +2004,19 @@ hf_subset=subset, evaluation_splits=("test",), few_shots_split="dev", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.TURKISH, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot), ) for subset in MMLU_SUBSETS - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # CMMLU: Chinese Massive Multitask Language Understanding @@ -2134,7 +2094,7 @@ cmmlu_tasks = [ LightevalTaskConfig( - name=f"cmmlu_{Language.CHINESE.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('cmmlu', Language.CHINESE, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( Language.CHINESE, lambda line: { @@ -2149,21 +2109,19 @@ hf_subset=subset, evaluation_splits=("test",), 
few_shots_split="dev", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.CHINESE, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot), ) for subset in CMMLU_SUBSETS - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # Arabic MMLU: Arabic version of MMLU @@ -2214,7 +2172,7 @@ arabic_mmlu_tasks = [ LightevalTaskConfig( - name=f"mmlu_{Language.ARABIC.value}_{formulation.name.lower()}:{normalize_subset(subset)}", + name=f"{get_mcf_task_name('mmlu', Language.ARABIC, formulation)}:{normalize_subset(subset)}", prompt_function=get_mcq_prompt_function( Language.ARABIC, lambda line: { @@ -2230,21 +2188,19 @@ hf_subset=subset, evaluation_splits=("test",), hf_avail_splits=["dev"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.ARABIC, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot), ) for subset in ARABIC_MMLU_SUBSETS - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] @@ -2262,7 +2218,7 @@ turkish_mmlu_tasks = [ LightevalTaskConfig( - name=f"mmlu_{Language.TURKISH.value}_{formulation.name.lower()}:{normalize_subset(subset)}", + name=f"{get_mcf_task_name('mmlu', Language.TURKISH, formulation)}:{normalize_subset(subset)}", prompt_function=get_mcq_prompt_function( Language.TURKISH, lambda line: { @@ -2277,21 +2233,19 @@ hf_subset=subset, evaluation_splits=("test",), few_shots_split="dev", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.TURKISH, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot), ) for subset in TURKISH_MMLU_SUBSET - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -2323,7 +2277,7 @@ # github: https://github.com/nlp-uoregon/mlmm-evaluation mlmm_arc_challenge_tasks = [ LightevalTaskConfig( - name=f"mlmm_arc_{language.value}_{formulation.name.lower()}:challenge", + name=f"{get_mcf_task_name('mlmm_arc', language, formulation)}:challenge", prompt_function=get_mcq_prompt_function( language, lambda line: { @@ -2342,14 +2296,16 @@ trust_dataset=True, evaluation_splits=("test",), few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.RUSSIAN, @@ -2379,11 +2335,7 @@ Language.TELUGU, Language.KANNADA, ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in 
DEFAULT_FORMULATIONS ] # Arabic ARC Easy @@ -2392,7 +2344,7 @@ # Paper: https://aclanthology.org/2023.arabicnlp-1.21/ arabic_ledarboard_arc_easy = [ LightevalTaskConfig( - name=f"alghafa_arc_{Language.ARABIC.value}_{formulation.name.lower()}:easy", + name=f"{get_mcf_task_name('alghafa_arc', Language.ARABIC, formulation)}:easy", prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation), suite=["lighteval"], hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Translated", @@ -2401,25 +2353,23 @@ trust_dataset=True, evaluation_splits=["test"], few_shots_split="validation", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.ARABIC, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] lumi_arc = [ LightevalTaskConfig( - name=f"lumi_arc_{language.value}_{formulation.name.lower()}:challenge", + name=f"{get_mcf_task_name('lumi_arc', language, formulation)}:challenge", prompt_function=get_mcq_prompt_function( language, lambda line: { @@ -2436,20 +2386,18 @@ hf_subset=standardize_tag(language.value), evaluation_splits=["test"], few_shots_split="validation", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS for language in [ Language.DANISH, Language.GERMAN, @@ -2469,7 +2417,7 @@ # Comes from the Turkish leaderboard turkish_arc_tasks = [ LightevalTaskConfig( - name=f"community_arc_{Language.TURKISH.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('community_arc', Language.TURKISH, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( Language.TURKISH, lambda line: { @@ -2486,26 +2434,24 @@ hf_subset=f"ARC-{subset.capitalize()}", evaluation_splits=("test",), hf_avail_splits=["train"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.TURKISH, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ] + ([loglikelihood_acc_metric(normalization=LogProbPMINorm())] if subset == "challenge" else []), # type: ignore ), + stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot), ) for subset in ["easy", "challenge"] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] hindi_arc_tasks = [ LightevalTaskConfig( - name=f"community_arc_{Language.HINDI.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('community_arc', Language.HINDI, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( Language.HINDI, lambda line: { @@ -2522,26 +2468,24 @@ hf_subset=f"ARC-{subset.capitalize()}", evaluation_splits=("test",), few_shots_split="validation", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.HINDI, [ 
loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ] + ([loglikelihood_acc_metric(normalization=LogProbPMINorm())] if subset == "challenge" else []), # type: ignore ), + stop_sequence=get_stop_sequence(Language.HINDI, formulation.cot), ) for subset in ["easy", "challenge"] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] arabic_arc_tasks = [ LightevalTaskConfig( - name=f"alghafa_arc_{Language.ARABIC.value}_{formulation.name.lower()}:easy", + name=f"{get_mcf_task_name('alghafa_arc', Language.ARABIC, formulation)}:easy", prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation), suite=["lighteval"], hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Translated", @@ -2550,25 +2494,23 @@ evaluation_splits=["test"], few_shots_split="validation", few_shots_select="sequential", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.ARABIC, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), trust_dataset=True, + stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] swahili_arc_tasks = [ LightevalTaskConfig( - name=f"community_arc_{Language.SWAHILI.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('community_arc', Language.SWAHILI, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( Language.SWAHILI, lambda line: { @@ -2588,21 +2530,19 @@ else "dc1df9df632d14c251594d9129fb833d2ca4429c", evaluation_splits=("test",), few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.SWAHILI, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ] + ([loglikelihood_acc_metric(normalization=LogProbPMINorm())] if subset == "challenge" else []), # type: ignore ), + stop_sequence=get_stop_sequence(Language.SWAHILI, formulation.cot), ) for subset in ["easy", "challenge"] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] @@ -2628,7 +2568,7 @@ # github: https://github.com/nlp-uoregon/mlmm-evaluation mlmm_truthfulqa_tasks = [ LightevalTaskConfig( - name=f"mlmm_truthfulqa_{language.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('mlmm_truthfulqa', language, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( language, partial( @@ -2648,13 +2588,15 @@ trust_dataset=True, evaluation_splits=("validation",), hf_avail_splits=["validation"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for subset in ["mc1", "mc2"] for language in [ @@ -2692,18 +2634,14 @@ Language.VIETNAMESE, Language.CHINESE, ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # Turkish TruthfulQA # Based on turkish leaderboard turkish_truthfulqa = [ LightevalTaskConfig( - 
name=f"community_truthfulqa_{Language.TURKISH.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('community_truthfulqa', Language.TURKISH, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( Language.TURKISH, partial( @@ -2721,20 +2659,18 @@ hf_subset="default", evaluation_splits=("validation",), hf_avail_splits=["validation"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.TURKISH, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot), ) for subset in ["mc1", "mc2"] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -2861,7 +2797,7 @@ exams_tasks = [ LightevalTaskConfig( - name=f"exams_{language.value}_{formulation.name.lower()}:{normalize_subset(subject)}", + name=f"{get_mcf_task_name('exams', language, formulation)}:{normalize_subset(subject)}", prompt_function=get_mcq_prompt_function( language, lambda line: { @@ -2884,21 +2820,19 @@ ), evaluation_splits=("test",), few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in exams_subjects_by_lang.keys() for subject in exams_subjects_by_lang[language] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # M3Exam: Multitask Multilingual Multimodal Evaluation Benchmark @@ -2906,7 +2840,7 @@ # Paper: https://arxiv.org/abs/2306.05179 m3exams_tasks = [ LightevalTaskConfig( - name=f"m3exams_{language.value}_{formulation.name.lower()}", + name=get_mcf_task_name("m3exams", language, formulation), suite=("lighteval",), prompt_function=get_mcq_prompt_function( language, @@ -2917,14 +2851,15 @@ hf_subset=LangCodeLanguage(standardize_tag(language.value)).language_name().lower(), evaluation_splits=("test",), few_shots_split="dev", - generation_size=-1, - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.AFRIKAANS, @@ -2937,11 +2872,7 @@ Language.THAI, Language.VIETNAMESE, ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # Thai Exams @@ -2953,27 +2884,29 @@ thai_exams_tasks = [ LightevalTaskConfig( - name=f"thai_exams_{Language.THAI.value}_{formulation.name.lower()}:{subset}", - prompt_function=get_mcq_prompt_function(Language.THAI, thai_exams_adapter, formulation=formulation), + name=f"{get_mcf_task_name('thai_exams', Language.THAI, formulation)}:{subset}", + prompt_function=get_mcq_prompt_function( + Language.THAI, + thai_exams_adapter, + formulation=formulation, + ), suite=("lighteval",), hf_repo="scb10x/thai_exam", hf_subset=subset, evaluation_splits=("test",), few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.THAI, [ 
loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.THAI, formulation.cot), ) for subset in THAI_EXAMS_SUBSETS - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -2992,7 +2925,7 @@ # Paper: https://arxiv.org/abs/2110.08462 xcsqa_tasks = [ LightevalTaskConfig( - name=f"xcsqa_{language.value}_{formulation.name.lower()}", + name=f"{get_mcf_task_name('xcsqa', language, formulation)}", prompt_function=get_mcq_prompt_function( language, lambda line: { @@ -3004,20 +2937,22 @@ ), suite=("lighteval",), hf_repo="INK-USC/xcsr", - hf_subset=f"X-CSQA-{standardize_tag(language.value) if language != Language.JAPANESE else 'jap'}", + hf_subset=f"X-CSQA-{standardize_tag(language.value)}", hf_filter=lambda x: all( len(x["question"]["choices"]["text"][i].strip()) > 0 for i in range(len(x["question"]["choices"]["text"])) ), evaluation_splits=("validation",), hf_avail_splits=["validation"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.ARABIC, @@ -3037,11 +2972,7 @@ Language.VIETNAMESE, Language.CHINESE, ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -3059,29 +2990,30 @@ # Arabic version: https://aclanthology.org/2023.arabicnlp-1.21/ piqa_ar_tasks = [ LightevalTaskConfig( - name=f"alghafa_piqa_{Language.ARABIC.value}_{formulation.name.lower()}", - prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation), + name=f"{get_mcf_task_name('alghafa_piqa', Language.ARABIC, formulation)}", + prompt_function=get_mcq_prompt_function( + Language.ARABIC, + alghafa_adapter, + formulation=formulation, + ), suite=["lighteval"], hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Translated", hf_revision="08663706ee7cab30c4b7dc1bb00042a3227ce1ff", hf_subset="piqa_ar", hf_avail_splits=["test", "validation"], evaluation_splits=["test"], - few_shots_split="validation", trust_dataset=True, - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.ARABIC, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -3099,8 +3031,12 @@ # Arabic version: https://aclanthology.org/2023.arabicnlp-1.21/ openbook_ara_tasks = [ LightevalTaskConfig( - name=f"alghafa_openbookqa_{Language.ARABIC.value}_{formulation.name.lower()}", - prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation), + name=f"{get_mcf_task_name('alghafa_openbookqa', Language.ARABIC, formulation)}", + prompt_function=get_mcq_prompt_function( + Language.ARABIC, + alghafa_adapter, + formulation=formulation, + ), suite=["lighteval"], hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Translated", 
hf_subset="openbook_qa_ext_ar", @@ -3108,67 +3044,30 @@ trust_dataset=True, evaluation_splits=["test"], few_shots_split="validation", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.ARABIC, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] -# Spanish version of OpenBookQA from BSC Language Technology group -# Dataset: https://huggingface.co/datasets/BSC-LT/openbookqa-es -openbook_es_tasks = [ +# The Russian version is part of the MERA (Multilingual Enhanced Russian NLP Architectures) project. +# Paper: https://arxiv.org/abs/2401.04531 +openbook_rus_tasks = [ LightevalTaskConfig( - name=f"openbookqa_{Language.SPANISH.value}_{formulation.name.lower()}", + name=f"{get_mcf_task_name('mera_openbookqa', Language.RUSSIAN, formulation)}", prompt_function=get_mcq_prompt_function( - Language.SPANISH, + Language.RUSSIAN, lambda line: { - "question": line["question_stem"], - "choices": line["choices"]["text"], - "gold_idx": LETTER_INDICES.index(line["answerKey"]), - }, - formulation=formulation, - ), - suite=["lighteval"], - hf_repo="BSC-LT/openbookqa-es", - hf_subset="default", - evaluation_splits=("test",), - few_shots_split="validation", - metric=get_metrics_for_formulation( - formulation, - [ - loglikelihood_acc_metric(normalization=LogProbTokenNorm()), - loglikelihood_acc_metric(normalization=LogProbCharNorm()), - ], - ), - ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] -] - - -# The Russian version is part of the MERA (Multilingual Enhanced Russian NLP Architectures) project. 
-# Paper: https://arxiv.org/abs/2401.04531 -openbook_rus_tasks = [ - LightevalTaskConfig( - name=f"mera_openbookqa_{Language.RUSSIAN.value}_{formulation.name.lower()}", - prompt_function=get_mcq_prompt_function( - Language.RUSSIAN, - lambda line: { - "question": line["inputs"]["question"], - "choices": [line["inputs"][f"option_{i.lower()}"] for i in LETTER_INDICES[:4]], - "gold_idx": LETTER_INDICES.index(line["outputs"]), + "question": line["inputs"]["question"], + "choices": [line["inputs"][f"option_{i.lower()}"] for i in LETTER_INDICES[:4]], + "gold_idx": LETTER_INDICES.index(line["outputs"]), }, formulation=formulation, ), @@ -3177,26 +3076,23 @@ hf_subset="ruopenbookqa", evaluation_splits=("train",), hf_avail_splits=["train"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.RUSSIAN, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( [ *openbook_rus_tasks, *openbook_ara_tasks, - *openbook_es_tasks, ] ) @@ -3209,7 +3105,7 @@ # Paper: https://aclanthology.org/2023.arabicnlp-1.21/ sciqa_ar_task = [ LightevalTaskConfig( - name=f"alghafa_sciqa_{Language.ARABIC.value}_{formulation.name.lower()}", + name=f"{get_mcf_task_name('alghafa_sciqa', Language.ARABIC, formulation)}", prompt_function=get_mcq_prompt_function( Language.ARABIC, sciqa_adapter, @@ -3223,20 +3119,18 @@ evaluation_splits=["test"], few_shots_split="validation", few_shots_select="sequential", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.ARABIC, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), trust_dataset=True, + stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -3253,7 +3147,7 @@ # MERA: https://github.com/ai-forever/MERA mathlogicqa_rus_tasks = [ LightevalTaskConfig( - name=f"mathlogic_qa_{Language.RUSSIAN.value}_{formulation.name.lower()}", + name=f"{get_mcf_task_name('mathlogic_qa', Language.RUSSIAN, formulation)}", prompt_function=get_mcq_prompt_function( Language.RUSSIAN, lambda line: { @@ -3268,69 +3162,64 @@ hf_subset="mathlogicqa", evaluation_splits=("train",), hf_avail_splits=["train"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.RUSSIAN, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot), ) - for formulation in [ - CFFormulation(), - MCFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] cmath_tasks = [ LightevalTaskConfig( - name=f"cmath_{Language.CHINESE.value}", - prompt_function=get_qa_prompt_function( + name=get_generative_task_name("cmath", Language.CHINESE, cot), + prompt_function=get_math_qa_prompt_function( Language.CHINESE, lambda line: { "question": line["question"], "choices": [line["golden"]], }, + cot=cot, ), suite=("lighteval",), hf_repo="weitianwen/cmath", hf_subset="default", evaluation_splits=("test",), 
few_shots_split="validation", - generation_size=25, + generation_size=get_cot_generaion_size(cot, 100), metric=[ - multilingual_quasi_exact_match_metric(Language.CHINESE, "full"), + multilingual_extractive_match_metric(Language.CHINESE), ], - stop_sequence=("\n",), + stop_sequence=get_stop_sequence(Language.CHINESE, cot), ) + for cot in (False, True) ] mgsm_tasks = [ LightevalTaskConfig( - name=f"mgsm_{language.value}", - prompt_function=get_qa_prompt_function( + name=get_generative_task_name("mgsm", language, cot), + prompt_function=get_math_qa_prompt_function( language, - lambda line: { - "question": line["question"], - # The cot is available but we have no use: - # line["answer"] - "choices": [str(line["answer_number"])], - }, + mgsm_adapter, + cot=cot, ), suite=("lighteval",), hf_repo="juletxara/mgsm", hf_subset=standardize_tag(language.value), evaluation_splits=("test",), few_shots_split="train", - generation_size=25, + generation_size=get_cot_generaion_size(cot, 100), metric=[ - multilingual_quasi_exact_match_metric(language, "full"), + multilingual_extractive_match_metric(language), ], - stop_sequence=("\n",), + stop_sequence=get_stop_sequence(language, cot), ) for language in [ - Language.ENGLISH, Language.SPANISH, Language.FRENCH, Language.GERMAN, @@ -3342,31 +3231,29 @@ Language.BENGALI, Language.TELUGU, ] + for cot in (False, True) ] + # African MGSM: MGSM for African Languages # From https://arxiv.org/abs/2406.03368. Human translated MGSM. afri_mgsm_tasks = [ LightevalTaskConfig( - name=f"afri_mgsm_{language.value}", - prompt_function=get_qa_prompt_function( + name=get_generative_task_name("mgsm", language, cot), + prompt_function=get_math_qa_prompt_function( language, - lambda line: { - "question": line["question"], - # The cot is available but we have no use: - # line["answer"] - "choices": [str(line["answer_number"])], - }, + mgsm_adapter, + cot=cot, ), suite=("lighteval",), hf_repo="masakhane/afrimgsm", hf_subset=language.value, evaluation_splits=("test",), few_shots_split="train", - generation_size=25, + generation_size=get_cot_generaion_size(cot, 100), metric=[ - multilingual_quasi_exact_match_metric(language, "full"), + multilingual_extractive_match_metric(language), ], - stop_sequence=("\n",), + stop_sequence=get_stop_sequence(language, cot), ) for language in [ Language.AMHARIC, @@ -3387,13 +3274,353 @@ Language.YORUBA, # Language.ZULU, ] + for cot in (False, True) +] + +# MSVAMP - Math Word Problems (Translated from SVAMP using Google Translate) +msvamp_tasks = [ + LightevalTaskConfig( + name=get_generative_task_name("msvamp", language, cot), + prompt_function=get_math_qa_prompt_function( + language, + lambda line: { + "question": line["m_query"], # Using the translated version of the question + "choices": [float_to_choice_string(line["response"])], # The answer as a string + }, + cot=cot, + ), + suite=("lighteval",), + hf_repo="Mathoctopus/MSVAMP", + hf_subset=standardize_tag(language.value), + evaluation_splits=("test",), + # Don't use balanced here as the biggest clusters are 1,2,3,4,5 which results + # in some llms just output 6 + few_shots_select="random", + hf_avail_splits=["test"], + generation_size=get_cot_generaion_size(cot, 100), + metric=[ + multilingual_extractive_match_metric(language), + ], + stop_sequence=get_stop_sequence(language, cot), + ) + for language in [ + Language.BENGALI, + Language.GERMAN, + Language.ENGLISH, + Language.SPANISH, + Language.FRENCH, + Language.JAPANESE, + Language.RUSSIAN, + Language.SWAHILI, + Language.THAI, + 
Language.CHINESE, + ] + for cot in (False, True) +] + + +# CMM-Math - Chinese Multimodal Math Dataset +# CMM-Math is a comprehensive Chinese mathematical reasoning dataset containing over 28,000 high-quality samples +# across 12 grade levels from elementary to high school. It includes multiple-choice and fill-in-the-blank questions +# with detailed solutions. The dataset features both text-only and multimodal problems (with visual context in +# questions/options). It consists of 22k+ training samples and 5k+ evaluation samples, designed to evaluate and +# enhance mathematical reasoning capabilities of large language and multimodal models. +# Note: Only the MCQ subset is implemented +cmm_math_mc_tasks = [ + LightevalTaskConfig( + name=get_mcf_task_name("cmm_math", Language.CHINESE, formulation), + prompt_function=get_mcq_prompt_function( + Language.CHINESE, + cmm_math_adapter, + formulation=formulation, + ), + suite=("lighteval",), + hf_repo="ecnu-icalk/cmm-math", + hf_subset="default", + hf_filter=lambda x: x["image"] == "[]" + and x["options"] + != "[]", # Only include examples without images (it's a string for some reason) and mcq questions + evaluation_splits=("test",), + few_shots_split="train", + metric=get_metrics_for_mcq_formulation( + formulation, + Language.CHINESE, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot), + ) + for formulation in DEFAULT_FORMULATIONS +] + + +# Math23K - Chinese Math Word Problem Dataset +# Math23K is a dataset containing 23,162 Chinese math word problems crawled from the internet. +# Originally introduced in "Deep Neural Solver for Math Word Problems", it consists of math word +# problems in Chinese text along with their corresponding numerical answers. The dataset is split +# into training and test sets and is commonly used to evaluate mathematical reasoning capabilities +# of language models on Chinese text. Each example contains the original Chinese problem text and +# its corresponding numerical solution. +math23k_tasks = [ + LightevalTaskConfig( + name=get_generative_task_name("math23k", Language.CHINESE, cot), + prompt_function=get_math_qa_prompt_function( + Language.CHINESE, + lambda line: { + "question": line["original_text"], # Use the original Chinese text + "choices": [str(line["ans"])], # Answer has computed number while ans is symbolic + }, + cot=cot, + ), + suite=("lighteval",), + hf_repo="Gxg/Math23K", + hf_subset="default", + evaluation_splits=("test",), + few_shots_split="train", + generation_size=get_cot_generaion_size(cot, 100), # Similar to other math tasks like msvamp + metric=[multilingual_extractive_match_metric(Language.CHINESE)], + stop_sequence=get_stop_sequence(Language.CHINESE, cot), + ) + for cot in (False, True) +] + +# TAL-SCQ5K - Chinese Math Word Problem Dataset +# TAL-SCQ5K is a high-quality mathematical competition dataset created by TAL Education Group, +# consisting of 5K multiple-choice questions (3K training, 2K testing) covering math topics from +# primary through high school levels. The dataset includes detailed solution steps and standardized +# LaTeX expressions. 
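# float_to_choice_string, added to tasks/multilingual/utils/adapters_utils.py in this diff, is the
# helper MSVAMP above (and ArMATH below) use to turn numeric gold answers into canonical strings so
# that integer-valued floats don't carry a trailing ".0" into the extractive match. A small sanity
# sketch of its behaviour (import path taken from this diff; illustrative only):

from lighteval.tasks.multilingual.utils.adapters_utils import float_to_choice_string

assert float_to_choice_string(3.0) == "3"        # integer-valued floats drop the ".0"
assert float_to_choice_string(2.5) == "2.5"      # non-integers keep their decimal form
assert float_to_choice_string(1250.0) == "1250"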
+ +tal_scq5k_tasks = [ + LightevalTaskConfig( + name=get_mcf_task_name("tal_scq5k", Language.CHINESE, formulation), + prompt_function=get_mcq_prompt_function( + Language.CHINESE, + lambda line: { + "question": line["problem"].strip(), + "choices": [ + opt[0]["content"] for opt in sorted(line["answer_option_list"], key=lambda x: x[0]["aoVal"]) + ], + "gold_idx": LETTER_INDICES.index(line["answer_value"]), + }, + formulation=formulation, + ), + suite=("lighteval",), + hf_repo="math-eval/TAL-SCQ5K", + hf_subset="default", + evaluation_splits=("test",), + few_shots_split="train", + metric=get_metrics_for_mcq_formulation( + formulation, + Language.CHINESE, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot), + ) + for formulation in DEFAULT_FORMULATIONS +] + +# MathQA-TR - Turkish Math Question Answering Dataset +# Sourced from https://github.com/esingedik/Turkish-MWP-Corpora-and-Code +# Translated using Google Translate +# MWP is collected from MAWPS, ASDiv-A, SVAMP. +mathqa_tr_tasks = [ + LightevalTaskConfig( + name=get_generative_task_name("mathqa", Language.TURKISH, cot), + prompt_function=get_math_qa_prompt_function( + Language.TURKISH, + lambda line: { + "question": line["question"], + "choices": [line["answer"]], + }, + cot=cot, + ), + suite=("lighteval",), + hf_repo="lighteval/MathQA-TR", + hf_subset="default", + evaluation_splits=("test",), + few_shots_split="train", + generation_size=get_cot_generaion_size(cot, 100), # Similar to other math tasks + metric=[ + multilingual_extractive_match_metric(Language.TURKISH), + ], + stop_sequence=get_stop_sequence(Language.TURKISH, cot), + ) + for cot in (False, True) +] + +mwp_tr_tasks = [ + LightevalTaskConfig( + name=get_generative_task_name("mwp", Language.TURKISH, cot), + prompt_function=get_math_qa_prompt_function( + Language.TURKISH, + lambda line: { + "question": line["question"], + "choices": [line["answer"]], + }, + cot=cot, + ), + suite=("lighteval",), + hf_repo="lighteval/MWP-TR", + hf_subset="default", + evaluation_splits=("test",), + few_shots_split="train", + generation_size=get_cot_generaion_size(cot, 100), # Similar to other math tasks + metric=[ + multilingual_extractive_match_metric(Language.TURKISH), + ], + stop_sequence=get_stop_sequence(Language.TURKISH, cot), + ) + for cot in (False, True) +] + +# MERA Arithmetic Tasks +# Paper: https://arxiv.org/abs/2401.04531 +mera_arithmetic_tasks = [ + LightevalTaskConfig( + name=get_generative_task_name("mera_arithmetic", Language.RUSSIAN, cot), + prompt_function=get_math_qa_prompt_function( + Language.RUSSIAN, + lambda line: { + "question": line["inputs"], + "choices": [str(line["outputs"])], + }, + cot=cot, + ), + suite=("lighteval",), + hf_repo="ai-forever/MERA", + hf_subset=subset, + evaluation_splits=("public_test",) + if subset == "rumodar" + else ("train",), # MERA uses train split for evaluation + hf_avail_splits=["public_test"] if subset == "rumodar" else ["train"], + generation_size=get_cot_generaion_size(cot, 100), # Similar to other math tasks + metric=[ + Metrics.quasi_exact_match_math, + ], + stop_sequence=get_stop_sequence(Language.RUSSIAN, cot), + ) + for subset in ["rumodar", "rumultiar", "simplear"] + for cot in (False, True) +] + +# QazUNTv2 Tasks - High school math problems in English and Russian +# A bilingual dataset for evaluating LLMs on high school math problems covering: +# - Algebra (436 problems) +# - Logic 
(312 problems) +# - Probability (163 problems) +# Each problem includes multiple choice options and detailed solutions. +# The dataset was manually curated, with English translations via Google Translate. +# Paper: https://doi.org/10.17632/52vc6v4czj.1 + +qazuntv2_tasks = [ + LightevalTaskConfig( + name=get_mcf_task_name("qazuntv2", lang, formulation), + prompt_function=get_mcq_prompt_function( + lang, + qazuntv2_adapter, + formulation=formulation, + ), + suite=("lighteval",), + hf_repo="lighteval/QazUNTv2", + hf_subset=standardize_tag(lang.value), + hf_filter=lambda x: x["section"].lower() == subset, + evaluation_splits=("train",), # Dataset only has train split + hf_avail_splits=["train"], + metric=get_metrics_for_mcq_formulation( + formulation, + lang, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + stop_sequence=get_stop_sequence(lang, formulation.cot), + ) + for lang in [ + Language.ENGLISH, + Language.RUSSIAN, + ] + for subset in ["algebra", "logic", "probability"] + for formulation in DEFAULT_FORMULATIONS ] + + +# ArMATH - Arabic Math Word Problems +# A dataset of 6,000 primary-school math word problems in Modern Standard Arabic (MSA). +# Paper: https://github.com/reem-codes/ArMATH +armath_tasks = [ + LightevalTaskConfig( + name=get_generative_task_name("armath", Language.ARABIC, cot), + prompt_function=get_math_qa_prompt_function( + Language.ARABIC, + lambda line: { + "question": line["question"], + "choices": [float_to_choice_string(line["answer"])], # Evaluate the equation to get answer + }, + cot=cot, + ), + suite=("lighteval",), + hf_repo="khalidalt/arMath", + hf_subset="default", + evaluation_splits=("test",), # Dataset only has test split + few_shots_split="validation", + generation_size=get_cot_generaion_size(cot, 100), # Similar to other math tasks + metric=[ + multilingual_extractive_match_metric(Language.ARABIC), + ], + stop_sequence=get_stop_sequence(Language.ARABIC, cot), + ) + for cot in (False, True) +] + +# HAWP - Hindi Arithmetic Word Problems +# A dataset containing 2.3k arithmetic word problems in Hindi, designed to evaluate +# mathematical reasoning capabilities of language models. Each problem includes the +# question text, equation, and numerical answer. 
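# Every generative math task in this block is emitted twice, once per value of `cot`, and
# get_cot_generaion_size from tasks/multilingual/utils/task_utils.py (added in this diff) picks the
# generation budget for the two variants: the short direct-answer cap is kept without CoT, while the
# CoT variant gets no fixed cap and relies on the stop sequences instead. A minimal check of that
# behaviour (import path taken from this diff; illustrative only):

from lighteval.tasks.multilingual.utils.task_utils import get_cot_generaion_size

assert get_cot_generaion_size(False, 100) == 100   # non-CoT: keep the 100-token budget
assert get_cot_generaion_size(True, 100) is None   # CoT: unbounded, cut off by stop sequences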
+hawp_tasks = [ + LightevalTaskConfig( + name=get_generative_task_name("hawp", Language.HINDI, cot), + prompt_function=get_math_qa_prompt_function( + Language.HINDI, + lambda line: { + "question": line["Problem"], + "choices": [str(line["answer"])], + }, + cot=cot, + ), + suite=("lighteval",), + hf_repo="lighteval/HAWP", + hf_subset="default", + evaluation_splits=("test",), + few_shots_split="dev", + generation_size=get_cot_generaion_size(cot, 100), + metric=[multilingual_extractive_match_metric(Language.HINDI, precision=6)], + stop_sequence=get_stop_sequence(Language.HINDI, cot), + ) + for cot in (False, True) +] + TASKS_TABLE.extend( [ *cmath_tasks, *mathlogicqa_rus_tasks, *mgsm_tasks, *afri_mgsm_tasks, + *armath_tasks, + *msvamp_tasks, + *cmm_math_mc_tasks, + *math23k_tasks, + *tal_scq5k_tasks, + *mathqa_tr_tasks, + *mwp_tr_tasks, + *mera_arithmetic_tasks, + *qazuntv2_tasks, + *hawp_tasks, ] ) @@ -3417,7 +3644,7 @@ agieval_tasks_zh = [ LightevalTaskConfig( - name=f"agieval_{Language.CHINESE.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('agieval', Language.CHINESE, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( Language.CHINESE, partial( @@ -3433,21 +3660,19 @@ evaluation_splits=("test",), hf_avail_splits=["test"], few_shots_split=None, - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.CHINESE, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), loglikelihood_acc_metric(normalization=LogProbPMINorm()), ], ), + stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot), ) for subset in CHINESE_AGIEVAL_SUBSET - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] # C-Eval: Chinese Evaluation suite # Similar to MMLu but with different categories @@ -3467,7 +3692,6 @@ "high_school_mathematics", "high_school_physics", "high_school_chemistry", - "high_school_biology", "middle_school_mathematics", "middle_school_biology", "middle_school_physics", @@ -3509,7 +3733,7 @@ ceval_tasks = [ LightevalTaskConfig( - name=f"ceval_{Language.CHINESE.value}_{formulation.name.lower()}:{subset}", + name=f"{get_mcf_task_name('ceval', Language.CHINESE, formulation)}:{subset}", prompt_function=get_mcq_prompt_function( Language.CHINESE, partial( @@ -3524,92 +3748,18 @@ hf_subset=subset, evaluation_splits=("val",), few_shots_split="dev", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.CHINESE, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot), ) for subset in CEVAL_SUBSET - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] -] - - -# OAB Exams: A collection of questions from the Brazilian Bar Association exam -# The exam is required for anyone who wants to practice law in Brazil -# Dataset: https://huggingface.co/datasets/eduagarcia/oab_exams -oab_exams_tasks = [ - LightevalTaskConfig( - name=f"oab_exams_{Language.PORTUGUESE.value}_{formulation.name.lower()}", - prompt_function=get_mcq_prompt_function( - Language.PORTUGUESE, - lambda line: { - "question": line["question"], - "choices": line["choices"]["text"], - "gold_idx": LETTER_INDICES.index(line["answerKey"]), - }, - formulation=formulation, - ), - suite=("lighteval",), - 
hf_repo="eduagarcia/oab_exams", - hf_subset="default", - evaluation_splits=("train",), - hf_avail_splits=["train"], - metric=get_metrics_for_formulation( - formulation, - [ - loglikelihood_acc_metric(normalization=LogProbTokenNorm()), - loglikelihood_acc_metric(normalization=LogProbCharNorm()), - ], - ), - ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] -] - -# ENEM (Exame Nacional do Ensino Médio) is a standardized Brazilian national secondary -# education examination. The exam is used both as a university admission test and as a -# high school evaluation test. -# Dataset: https://huggingface.co/datasets/maritaca-ai/enem -enem_tasks = [ - LightevalTaskConfig( - name=f"enem_{Language.PORTUGUESE.value}_{formulation.name.lower()}:{year}", - prompt_function=get_mcq_prompt_function( - Language.PORTUGUESE, - partial( - enem_adapter, - Language.PORTUGUESE, - ), - formulation=formulation, - ), - suite=("lighteval",), - hf_repo="maritaca-ai/enem", - hf_subset=year, - evaluation_splits=("train",), - hf_avail_splits=["train"], - metric=get_metrics_for_formulation( - formulation, - [ - loglikelihood_acc_metric(normalization=LogProbTokenNorm()), - loglikelihood_acc_metric(normalization=LogProbCharNorm()), - ], - ), - ) - for year in ["2022", "2023", "2024"] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] @@ -3619,7 +3769,7 @@ # MERA: https://github.com/ai-forever/MERA worldtree_rus_tasks = [ LightevalTaskConfig( - name=f"mera_worldtree_{Language.RUSSIAN.value}_{formulation.name.lower()}", + name=f"{get_mcf_task_name('mera_worldtree', Language.RUSSIAN, formulation)}", prompt_function=get_mcq_prompt_function( Language.RUSSIAN, lambda line: { @@ -3634,19 +3784,17 @@ hf_subset="ruworldtree", evaluation_splits=("train",), hf_avail_splits=["train"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + Language.RUSSIAN, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -3654,8 +3802,6 @@ *agieval_tasks_zh, *worldtree_rus_tasks, *ceval_tasks, - *oab_exams_tasks, - *enem_tasks, ] ) @@ -3663,20 +3809,22 @@ # ------------------------------- Continuation Tasks ------------------------------- # xcodah_tasks = [ LightevalTaskConfig( - name=f"xcodah_{language.value}_{formulation.name.lower()}", + name=f"{get_mcf_task_name('xcodah', language, formulation)}", prompt_function=get_mcq_prompt_function(language, partial(xcodah_adapter, language), formulation=formulation), suite=("lighteval",), hf_repo="INK-USC/xcsr", - hf_subset=f"X-CODAH-{standardize_tag(language.value) if language != Language.JAPANESE else 'jap'}", + hf_subset=f"X-CODAH-{standardize_tag(language.value)}", evaluation_splits=("validation",), hf_avail_splits=["validation"], - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + language, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.ARABIC, @@ -3696,16 +3844,12 @@ Language.VIETNAMESE, Language.CHINESE, ] - for formulation in [ - 
MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] xstory_tasks = [ LightevalTaskConfig( - name=f"xstory_cloze_{lang.value}_{formulation.name.lower()}", + name=f"{get_mcf_task_name('xstory_cloze', lang, formulation)}", prompt_function=get_continuation_prompt_function( lang, partial( @@ -3730,13 +3874,15 @@ hf_subset=standardize_tag(lang.value), evaluation_splits=["eval"], few_shots_split="train", - metric=get_metrics_for_formulation( + metric=get_metrics_for_mcq_formulation( formulation, + lang, [ loglikelihood_acc_metric(normalization=LogProbTokenNorm()), loglikelihood_acc_metric(normalization=LogProbCharNorm()), ], ), + stop_sequence=get_stop_sequence(lang, formulation.cot), ) for lang in [ Language.RUSSIAN, @@ -3750,11 +3896,7 @@ Language.BASQUE, Language.BURMESE, ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -3768,20 +3910,27 @@ xwinograd_tasks = [ LightevalTaskConfig( - name=f"xwinograd_{language.value}_{formulation.name.lower()}", + name=f"{get_mcf_task_name('xwinograd', language, formulation)}", suite=("lighteval",), prompt_function=get_continuation_prompt_function( - language, partial(winogrand_adapter, language), formulation=formulation + language, + partial(winogrand_adapter, language), + formulation=formulation, ), hf_repo="Muennighoff/xwinograd", - hf_subset=standardize_tag(language.value) if language != Language.JAPANESE else "jp", + hf_subset=standardize_tag(language.value), evaluation_splits=("test",), hf_avail_splits=["test"], - metric=[ - loglikelihood_acc_metric(normalization=None), - loglikelihood_acc_metric(normalization=LogProbTokenNorm()), - loglikelihood_acc_metric(normalization=LogProbCharNorm()), - ], + metric=get_metrics_for_mcq_formulation( + formulation, + language, + [ + loglikelihood_acc_metric(normalization=None), + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + stop_sequence=get_stop_sequence(language, formulation.cot), ) for language in [ Language.ENGLISH, @@ -3791,35 +3940,34 @@ Language.RUSSIAN, Language.CHINESE, ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] winograd_turkish_task = [ LightevalTaskConfig( - name=f"community_xwinograd_{Language.TURKISH.value}_{formulation.name.lower()}", + name=f"{get_mcf_task_name('community_xwinograd', Language.TURKISH, formulation)}", suite=("lighteval",), prompt_function=get_continuation_prompt_function( - Language.TURKISH, partial(winogrand_adapter, Language.TURKISH), formulation=formulation + Language.TURKISH, + partial(winogrand_adapter, Language.TURKISH), + formulation=formulation, ), hf_repo="malhajar/winogrande-tr-v0.2", hf_subset="default", evaluation_splits=("validation",), few_shots_split="train", - metric=[ - loglikelihood_acc_metric(normalization=None), - loglikelihood_acc_metric(normalization=LogProbTokenNorm()), - loglikelihood_acc_metric(normalization=LogProbCharNorm()), - ], + metric=get_metrics_for_mcq_formulation( + formulation, + Language.TURKISH, + [ + loglikelihood_acc_metric(normalization=None), + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot), ) - for formulation in [ - MCFFormulation(), - CFFormulation(), - 
HybridFormulation(), - ] + for formulation in DEFAULT_FORMULATIONS ] TASKS_TABLE.extend( @@ -3844,8 +3992,8 @@ mkqa_tasks = [ LightevalTaskConfig( - name=f"mkqa_{language.value}:{subset}", - prompt_function=get_qa_prompt_function(language, partial(get_mkqa_adapter, language)), + name=get_generative_task_name("mkqa", language, cot), + prompt_function=get_qa_prompt_function(language, partial(get_mkqa_adapter, language), cot=cot), suite=("lighteval",), hf_repo="apple/mkqa", hf_subset="mkqa", @@ -3861,15 +4009,17 @@ trust_dataset=True, evaluation_splits=("train",), hf_avail_splits=["train"], - stop_sequence=("\n",), metric=[ - multilingual_quasi_exact_match_metric(language, "prefix"), - multilingual_quasi_f1_score_metric(language), + multilingual_quasi_exact_match_metric( + language, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(language, normalize_pred=get_cot_answer_normalization(cot)), ] if subset in ["entity", "long_answer", "short_phrase"] else [ - multilingual_quasi_exact_match_metric(language, "full"), + multilingual_quasi_exact_match_metric(language, "full", normalize_pred=get_cot_answer_normalization(cot)), ], + stop_sequence=get_stop_sequence(language, cot), ) for subset in MKQA_TASK_TO_ID.keys() for language in [ @@ -3900,6 +4050,7 @@ # Language.CHINESE_HONG_KONG, # Language.CHINESE_TRADITIONAL, ] + for cot in (False, True) ] mintaka_tasks = [ @@ -3911,18 +4062,19 @@ "question": line["question"], "choices": [line["answerText"]], }, + cot=cot, ), suite=("lighteval",), hf_repo="AmazonScience/mintaka", hf_subset=standardize_tag(lang.value), evaluation_splits=("test",), few_shots_split="train", - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), metric=[ multilingual_quasi_exact_match_metric(lang, "prefix"), multilingual_quasi_f1_score_metric(lang), ], + stop_sequence=get_stop_sequence(lang, cot=cot), ) for lang in [ Language.ARABIC, @@ -3935,55 +4087,64 @@ Language.JAPANESE, Language.PORTUGUESE, ] + for cot in (False, True) ] french_triviqa_tasks = [ LightevalTaskConfig( - name=f"community_triviaqa_{Language.FRENCH.value}", + name=get_generative_task_name("community_triviaqa", Language.FRENCH, cot), prompt_function=get_qa_prompt_function( Language.FRENCH, lambda line: { "question": line["Question"], "choices": [line["Answer"]], }, + cot=cot, ), suite=("lighteval",), hf_repo="manu/french-trivia", hf_subset="default", evaluation_splits=("train",), hf_avail_splits=["train"], - generation_size=400, - stop_sequence=("\n",), + generation_size=get_cot_generaion_size(cot, 400), metric=[ - multilingual_quasi_exact_match_metric(Language.FRENCH, "prefix"), - multilingual_quasi_f1_score_metric(Language.FRENCH), + multilingual_quasi_exact_match_metric( + Language.FRENCH, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(Language.FRENCH, normalize_pred=get_cot_answer_normalization(cot)), ], + stop_sequence=get_stop_sequence(Language.FRENCH, cot), ) + for cot in (False, True) ] chegeka_tasks = [ LightevalTaskConfig( - name=f"chegeka_{Language.RUSSIAN.value}", + name=get_generative_task_name("chegeka", Language.RUSSIAN, cot), prompt_function=get_qa_prompt_function( Language.RUSSIAN, lambda line: { "question": line["inputs"]["text"], "choices": [line["outputs"]], }, + cot=cot, ), suite=("lighteval",), hf_repo="ai-forever/MERA", hf_subset="chegeka", evaluation_splits=("train",), hf_avail_splits=["train"], - generation_size=400, - stop_sequence=("\n",), 
+ generation_size=get_cot_generaion_size(cot, 400), metric=[ - multilingual_quasi_exact_match_metric(Language.RUSSIAN, "prefix"), - multilingual_quasi_f1_score_metric(Language.RUSSIAN), + multilingual_quasi_exact_match_metric( + Language.RUSSIAN, "prefix", normalize_pred=get_cot_answer_normalization(cot) + ), + multilingual_quasi_f1_score_metric(Language.RUSSIAN, normalize_pred=get_cot_answer_normalization(cot)), ], + stop_sequence=get_stop_sequence(Language.RUSSIAN, cot), ) + for cot in (False, True) ] TASKS_TABLE.extend( @@ -4061,7 +4222,7 @@ acva_tasks = [ LightevalTaskConfig( - name=f"acva_{Language.ARABIC.value}:{subset}", + name=get_generative_task_name("acva", Language.ARABIC, cot), prompt_function=get_boolq_prompt_function( Language.ARABIC, lambda line: { @@ -4075,11 +4236,17 @@ hf_subset=subset, evaluation_splits=("test",), few_shots_split="validation", - metric=[multilingual_quasi_exact_match_metric(Language.ARABIC, "full"), loglikelihood_acc_metric()], - generation_size=5, + metric=[ + multilingual_quasi_exact_match_metric( + Language.ARABIC, "full", normalize_pred=get_cot_answer_normalization(cot) + ), + loglikelihood_acc_metric(), + ], + generation_size=get_cot_generaion_size(cot, 5), stop_sequence=("\n",), ) for subset in ACVA_SUBSET + for cot in (False, True) ] @@ -4368,7 +4535,7 @@ def flores_adapter(lang1, lang2): source_language=Language(manage_duplicate_language_codes(lang1.split("_")[0])), target_language=Language(manage_duplicate_language_codes(lang2.split("_")[0])), adapter=flores_adapter(lang1, lang2), - formulation=CFFormulation(), + formulation=CFFormulation(cot=cot), ), suite=("lighteval",), hf_repo="facebook/flores", @@ -4377,13 +4544,19 @@ def flores_adapter(lang1, lang2): evaluation_splits=["devtest"], few_shots_split="dev", few_shots_select=None, - generation_size=300, - metric=[Metrics.chrf_plus, Metrics.bleu, Metrics.bleu_1, Metrics.bleu_4], - stop_sequence=["\n"], + generation_size=get_cot_generaion_size(cot, 300), + metric=[ + translation_metric(metric_name="chrf++", normalize_pred=get_cot_answer_normalization(cot)), + translation_metric(metric_name="bleu", normalize_pred=get_cot_answer_normalization(cot)), + translation_metric(metric_name="bleu_1", normalize_pred=get_cot_answer_normalization(cot)), + translation_metric(metric_name="bleu_4", normalize_pred=get_cot_answer_normalization(cot)), + ], + stop_sequence=get_stop_sequence(Language.ENGLISH, cot), trust_dataset=True, version=0, ) for (lang1, lang2) in combinations(flores_200_languages, 2) + for cot in (False, True) ] TASKS_TABLE.extend( diff --git a/src/lighteval/tasks/multilingual/utils/adapters_utils.py b/src/lighteval/tasks/multilingual/utils/adapters_utils.py index 2e7f27dda..e30bbf512 100644 --- a/src/lighteval/tasks/multilingual/utils/adapters_utils.py +++ b/src/lighteval/tasks/multilingual/utils/adapters_utils.py @@ -133,3 +133,8 @@ def extract_answer(acc: tuple[str, int, list[str]], symbol: str) -> tuple[str, i answer[: len(prefix)]: answer[len(prefix) :].strip() for answer, prefix in zip(found_answers, answer_prefixes) } return last_index, prefix_answer_dict + + +def float_to_choice_string(answer: float) -> str: + answer = float(answer) + return str(int(answer)) if answer.is_integer() else str(answer) diff --git a/src/lighteval/tasks/multilingual/utils/task_utils.py b/src/lighteval/tasks/multilingual/utils/task_utils.py index d8e73dac8..29053ec2f 100644 --- a/src/lighteval/tasks/multilingual/utils/task_utils.py +++ b/src/lighteval/tasks/multilingual/utils/task_utils.py @@ -21,22 +21,74 @@ # 
SOFTWARE. -from lighteval.metrics.dynamic_metrics import loglikelihood_acc_metric +import re +from typing import Callable + +from lighteval.metrics.dynamic_metrics import ( + IndicesExtractionConfig, + loglikelihood_acc_metric, + multilingual_extractive_match_metric, +) from lighteval.metrics.utils.metric_utils import Metric from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation +from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS +from lighteval.utils.language import Language def normalize_subset(subset: str) -> str: return subset.replace(" ", "_").replace("(", "").replace(")", "").lower() -def get_metrics_for_formulation(formulation: Formulation, metrics: list[Metric]) -> list[Metric]: +def get_metrics_for_mcq_formulation( + formulation: Formulation, language: Language, metrics: list[Metric] +) -> list[Metric]: """ Choose the appropriate metrics for the given formulation otherwise fallback to the original metrics. """ match formulation: - # - case MCFFormulation(choice_prefix="Letters"): + case MCFFormulation(choice_prefix="Letters" | "Numbers", cot=False): return [loglikelihood_acc_metric(normalization=None)] + # In case of NativeLetters we can't use just acc_metric, because the letters can be made of multiple tokens + case MCFFormulation(cot=False): + return metrics + case MCFFormulation(cot=True): + return [ + multilingual_extractive_match_metric( + language, + gold_extraction_target=(IndicesExtractionConfig(prefix_for_extraction=formulation.choice_prefix),), + ), + ] case _: return metrics + + +def get_cot_generaion_size(cot: bool, generation_size: int) -> int | None: + return None if cot else generation_size + + +def get_stop_sequence(language: Language, cot: bool) -> list[str] | None: + stop_sequence = ["\n"] if cot else [] + try: + trans = TRANSLATION_LITERALS[language] + # Ensure it's on a new line as otherwise LLM's sometimes like to generate: + # "**Répondez à la" or "1. 
**Comprendre la" in their cot generations + return [ + f"\n{trans.question_word}{trans.colon}", + f"\n{trans.question_word.capitalize()}{trans.colon}", + ] + stop_sequence + except (AttributeError, KeyError): + return stop_sequence + + +def get_cot_answer_normalization(cot: bool = False) -> Callable[[str], str] | None: + if not cot: + return None + + compiled_b_regex = re.compile(r"(.*?)") + + def bb_normalizer(text: str) -> str: + matches = compiled_b_regex.findall(text) + last = matches[-1] if matches else text + return last + + return bb_normalizer diff --git a/src/lighteval/tasks/templates/boolq.py b/src/lighteval/tasks/templates/boolq.py index 09959e874..e0b19575b 100644 --- a/src/lighteval/tasks/templates/boolq.py +++ b/src/lighteval/tasks/templates/boolq.py @@ -46,6 +46,7 @@ class BoolQInput(TypedDict): answer: bool instruction: NotRequired[str] context: NotRequired[str] + few_shot_cot: NotRequired[str] class BoolQAdapter(TypedDict): @@ -62,6 +63,7 @@ class BoolQAdapter(TypedDict): answer: str instruction: NotRequired[str] context: NotRequired[str] + few_shot_cot: NotRequired[str] def get_boolq_prompt_function( @@ -90,6 +92,7 @@ def get_boolq_prompt_function( "context": "context", "instruction": "instruction", "gold_idx": "gold_idx", + "few_shot_cot": "few_shot_cot", }, formulation, ) @@ -105,11 +108,13 @@ def boolq_prompt( choices = [translation_literals.yes, translation_literals.no] return mcq_prompt_fn( { + **{x: line[x] for x in line if x.startswith("__")}, "instruction": input_data.get("instruction", ""), "question": input_data["question"], "context": input_data.get("context", ""), "choices": choices, "gold_idx": 0 if input_data["answer"] else 1, + "few_shot_cot": input_data.get("few_shot_cot", ""), }, task_name, ) diff --git a/src/lighteval/tasks/templates/continuation.py b/src/lighteval/tasks/templates/continuation.py index c9cd5d1bc..9266986c3 100644 --- a/src/lighteval/tasks/templates/continuation.py +++ b/src/lighteval/tasks/templates/continuation.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import logging from typing import Callable from typing_extensions import NotRequired, TypedDict @@ -44,9 +45,11 @@ from lighteval.utils.utils import as_list +logger = logging.getLogger(__name__) + CONTINUATION_QUERY_CF = "{instruction}{context}" -CONTINUATION_QUERY_MCF = "{instruction}{context}\n{options}{answer_word}{colon}" +CONTINUATION_QUERY_MCF = "{instruction}{context}\n\n{options_word}{colon}\n{options}{answer_word}{colon}" # Defined for type hinting only @@ -64,6 +67,7 @@ class ContinuationInput(TypedDict): continuations: list[str] gold_idx: list[int] | int instruction: NotRequired[str] + few_shot_cot: NotRequired[str] class ContinuationDictAdapter(TypedDict): @@ -80,6 +84,7 @@ class ContinuationDictAdapter(TypedDict): continuations: str gold_idx: str instruction: NotRequired[str] + few_shot_cot: NotRequired[str] def get_continuation_prompt_function( @@ -107,6 +112,8 @@ def get_continuation_prompt_function( *MCF* Context + + Options: A. Continuation 1 B. Continuation 2 C. 
Continuation 3 @@ -126,13 +133,31 @@ def get_continuation_prompt_function( adapter_fn = create_adapter_from_dict(adapter) translation_literals = TRANSLATION_LITERALS[language] + WARNED_ABOUT_COT_INSTRUCTION = False + def prepare_prompt(line: dict): cont_input = adapter_fn(line) if cont_input is None: return None instruction_val = cont_input.get("instruction") - instruction = f"{instruction_val}\n" if instruction_val else "" + if formulation.cot and not instruction_val: + if not isinstance(formulation, MCFFormulation) or formulation.choice_prefix not in [ + "Letters", + "NativeLetters", + ]: + raise ValueError( + "You are using a COT with a unsupported formulation. Either use MCF formulation or provide an instruction." + ) + + instruction_val = f"{translation_literals.continuation_mcf_instruction}\n{translation_literals.default_formatting_instruction}" + nonlocal WARNED_ABOUT_COT_INSTRUCTION + if not WARNED_ABOUT_COT_INSTRUCTION: + logger.warning( + f" You are using a COT with MCF formulation but did not provide an instruction. Defaulting to {instruction_val}" + ) + WARNED_ABOUT_COT_INSTRUCTION = True + instruction = f"{instruction_val}\n\n" if instruction_val else "" context = ( f"{capitalize(fix_ending_punct(cont_input['context'], translation_literals))}" @@ -182,16 +207,31 @@ def prompt_fn_mcf(line, task_name: str): options = build_choices(continuations, formulation, translation_literals) options = f"{options}\n" if options else "" - answers = build_answers(continuations, formulation, translation_literals) - - answer_word = capitalize(translation_literals.answer) + answer_word = capitalize( + translation_literals.answer if not formulation.cot else translation_literals.answer_cot + ) + options_word = capitalize(translation_literals.options_word) query = CONTINUATION_QUERY_MCF.format( instruction=instruction, context=context, options=options, answer_word=answer_word, colon=translation_literals.colon, + options_word=options_word, + ) + + few_shot_cot = cont_input.get("few_shot_cot", None) + is_few_shot = line.get("__few_shots", False) + if formulation.cot and few_shot_cot and is_few_shot: + continuations = [ + capitalize(fix_ending_punct(answer, translation_literals)) for answer in as_list(few_shot_cot) + ] + answers = build_answers( + continuations, + formulation, + translation_literals, + is_few_shot=is_few_shot and bool(few_shot_cot), ) return Doc( diff --git a/src/lighteval/tasks/templates/copa.py b/src/lighteval/tasks/templates/copa.py index a4d82c4de..46f25201d 100644 --- a/src/lighteval/tasks/templates/copa.py +++ b/src/lighteval/tasks/templates/copa.py @@ -53,6 +53,7 @@ class COPAInput(TypedDict): continuations: list[str] gold_idx: int | list[int] instruction: NotRequired[str] + few_shot_cot: NotRequired[str] class COPAAdapter(TypedDict): @@ -71,6 +72,7 @@ class COPAAdapter(TypedDict): continuations: str gold_idx: str instruction: NotRequired[str] + few_shot_cot: NotRequired[str] def get_copa_prompt_function( @@ -113,7 +115,15 @@ def get_copa_prompt_function( """ adapter_fn = create_adapter_from_dict(adapter) continuation_prompt_fn = get_continuation_prompt_function( - language, {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, formulation + language, + { + "context": "context", + "continuations": "continuations", + "gold_idx": "gold_idx", + "few_shot_cot": "few_shot_cot", + "instruction": "instruction", + }, + formulation, ) translation_literals = TRANSLATION_LITERALS[language] @@ -140,10 +150,12 @@ def copa_prompt( return continuation_prompt_fn( { + 
**{x: line[x] for x in line if x.startswith("__")}, "instruction": input_data.get("instruction", ""), "context": context, "continuations": input_data["continuations"], "gold_idx": input_data["gold_idx"], + "few_shot_cot": input_data.get("few_shot_cot", ""), }, task_name, ) diff --git a/src/lighteval/tasks/templates/hellaswag.py b/src/lighteval/tasks/templates/hellaswag.py index 79108a9cd..2f6f0b07d 100644 --- a/src/lighteval/tasks/templates/hellaswag.py +++ b/src/lighteval/tasks/templates/hellaswag.py @@ -49,6 +49,7 @@ class HellaswagInput(TypedDict): instruction: NotRequired[str] activity_label: NotRequired[str] ctx_b: NotRequired[str] + few_shot_cot: NotRequired[str] class HellaswagAdapter(TypedDict): @@ -58,6 +59,7 @@ class HellaswagAdapter(TypedDict): instruction: NotRequired[str] activity_label: NotRequired[str] ctx_b: NotRequired[str] + few_shot_cot: NotRequired[str] def get_hellaswag_prompt_function( @@ -107,7 +109,15 @@ def join_ctxs(ctx_a, ctx_b): adapter_fn = create_adapter_from_dict(adapter) continuation_prompt_fn = get_continuation_prompt_function( - language, {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, formulation + language, + { + "context": "context", + "continuations": "continuations", + "gold_idx": "gold_idx", + "instruction": "instruction", + "few_shot_cot": "few_shot_cot", + }, + formulation, ) def hellaswag_prompt( @@ -145,10 +155,12 @@ def hellaswag_prompt( return continuation_prompt_fn( { + **{x: line[x] for x in line if x.startswith("__")}, "instruction": input_data.get("instruction", ""), "context": full_context, "continuations": choices, "gold_idx": input_data["gold_idx"], + "few_shot_cot": input_data.get("few_shot_cot", ""), }, task_name, ) diff --git a/src/lighteval/tasks/templates/math_qa.py b/src/lighteval/tasks/templates/math_qa.py new file mode 100644 index 000000000..5bc591abf --- /dev/null +++ b/src/lighteval/tasks/templates/math_qa.py @@ -0,0 +1,87 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
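Before the new math_qa template below, here is an illustrative sketch of how a task adapter is expected to feed the few_shot_cot and instruction keys that the COPA and Hellaswag templates above now forward to the continuation template; the dataset columns premise, choice1, choice2, label and rationale are hypothetical names, not fields of any particular dataset.

from lighteval.tasks.templates.copa import get_copa_prompt_function
from lighteval.tasks.templates.utils.formulation import MCFFormulation
from lighteval.utils.language import Language

# Callable adapter mapping a (hypothetical) dataset row onto COPAInput.
copa_prompt_fn = get_copa_prompt_function(
    Language.ENGLISH,
    lambda line: {
        "cause_effect": "cause",
        "context": line["premise"],
        "continuations": [line["choice1"], line["choice2"]],
        "gold_idx": line["label"],
        # Rendered only for few-shot examples when the formulation has cot=True.
        "few_shot_cot": line.get("rationale", ""),
    },
    MCFFormulation(cot=True),
)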
+ +import logging +from typing import Callable + +from lighteval.tasks.templates.multichoice import MCQInput, create_adapter_from_dict, get_mcq_prompt_function +from lighteval.tasks.templates.qa import QAAdapter, QAInput +from lighteval.tasks.templates.utils.formulation import CFFormulation +from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS +from lighteval.utils.language import Language + + +logger = logging.getLogger(__name__) + + +def get_math_qa_prompt_function( + language: Language, adapter: Callable[[dict], QAInput | None] | QAAdapter, cot: bool = False +): + """ + Create a templated prompt function for a QA task. + Example tasks: + - MathQA + - GSM8K + + Format: + Question: xxx + Answer: | Answer + + Args: + language (Language): The language of the QA task. + adapter (Callable[[dict], QAInput] | QAAdapter): A function or dictionary to adapt the input data to the required QAInput format. + Must map data from the dataset row to the QAInput format. + + Returns: + Callable: A function that generates QA prompts based on the given parameters. + """ + + adapter_fn = create_adapter_from_dict(adapter) + WARNED_ABOUT_INSTRUCTION = False + + def adapter_for_mcq(line: dict) -> MCQInput | None: + input_data = adapter_fn(line) + if input_data is None: + return None + + choices = input_data["choices"] + instruction = input_data.get("instruction", "") + if cot and not instruction: + translation_literals = TRANSLATION_LITERALS[language] + instruction = f"{translation_literals.qa_instruction}\n{translation_literals.math_formatting_instruction}" + nonlocal WARNED_ABOUT_INSTRUCTION + if not WARNED_ABOUT_INSTRUCTION: + logger.warning( + f"You are using Math-QA with cot, but did not provide instruction. Default to {instruction}." + ) + WARNED_ABOUT_INSTRUCTION = True + + return { + **input_data, + "gold_idx": list(range(len(choices))), + "instruction": instruction, + } + + multichoice_prompt_fn = get_mcq_prompt_function( + language, adapter=adapter_for_mcq, formulation=CFFormulation(cot=cot) + ) + return multichoice_prompt_fn diff --git a/src/lighteval/tasks/templates/multichoice.py b/src/lighteval/tasks/templates/multichoice.py index 92808488e..2b89205f1 100644 --- a/src/lighteval/tasks/templates/multichoice.py +++ b/src/lighteval/tasks/templates/multichoice.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
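The math_qa template above wraps the multichoice template with a CF formulation and, when cot=True, falls back to the qa_instruction and math_formatting_instruction literals. A hedged usage sketch mirroring the accompanying test; the row fields problem and solution are illustrative:

from lighteval.tasks.templates.math_qa import get_math_qa_prompt_function
from lighteval.utils.language import Language

math_prompt_fn = get_math_qa_prompt_function(
    language=Language.ENGLISH,
    adapter=lambda line: {"question": line["problem"], "choices": [line["solution"]]},
    cot=True,
)

doc = math_prompt_fn({"problem": "Solve for x: x + 5 = 10", "solution": "5"}, "math_qa_example")
# With no explicit instruction, the default QA and math formatting instructions are
# prepended and the query ends with the step-by-step answer cue.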
+import logging from typing import Callable from typing_extensions import NotRequired, TypedDict @@ -27,12 +28,20 @@ from lighteval.tasks.requests import Doc from lighteval.tasks.templates.utils.adapter_utils import create_adapter_from_dict from lighteval.tasks.templates.utils.formatting_utils import capitalize, fix_ending_punct -from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation, build_answers, build_choices +from lighteval.tasks.templates.utils.formulation import ( + CFFormulation, + Formulation, + MCFFormulation, + build_answers, + build_choices, +) from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS from lighteval.utils.language import Language from lighteval.utils.utils import as_list +logger = logging.getLogger(__name__) + MULTI_CHOICE_QA_QUERY = ( "{instruction}{context}{question_word}{colon}{sentence_space}{question}\n{options}{answer_word}{colon}" ) @@ -55,6 +64,7 @@ class MCQInput(TypedDict): gold_idx: list[int] | int context: NotRequired[str] instruction: NotRequired[str] + few_shot_cot: NotRequired[str] class MCQDictAdapter(TypedDict): @@ -73,6 +83,7 @@ class MCQDictAdapter(TypedDict): gold_idx: str context: NotRequired[str] instruction: NotRequired[str] + few_shot_cot: NotRequired[str] # Python too dumb to do fancy inference :( @@ -122,6 +133,8 @@ def get_mcq_prompt_function( adapter_fn = create_adapter_from_dict(adapter) + WARNED_ABOUT_COT_INSTRUCTION = False + def prompt_fn(line, task_name: str): mcq_input = adapter_fn(line) if mcq_input is None: @@ -130,19 +143,42 @@ def prompt_fn(line, task_name: str): translation_literals = TRANSLATION_LITERALS[language] instruction_val = mcq_input.get("instruction") - instruction = f"{instruction_val}\n" if instruction_val else "" + if formulation.cot and not instruction_val: + match formulation: + case MCFFormulation(choice_prefix="Letters") | MCFFormulation(choice_prefix="NativeLetters"): + instruction_val = f"{translation_literals.multichoice_mcf_instruction}\n{translation_literals.default_formatting_instruction}" + case CFFormulation(): + instruction_val = ( + f"{translation_literals.qa_instruction}\n{translation_literals.default_formatting_instruction}" + ) + case _: + raise ValueError( + "You are using a COT with a unsupported formulation. Either use CF/MCF formulation or provide an instruction." + ) + + nonlocal WARNED_ABOUT_COT_INSTRUCTION + if not WARNED_ABOUT_COT_INSTRUCTION: + logger.warning( + f"You are using a COT with CF/MCF formulation but did not provide an instruction. 
Defaulting to {instruction_val}" + ) + WARNED_ABOUT_COT_INSTRUCTION = True + + instruction = f"{instruction_val}\n\n" if instruction_val else "" context_val = mcq_input.get("context") context = f"{capitalize(fix_ending_punct(context_val, translation_literals))}\n" if context_val else "" question = capitalize(fix_ending_punct(mcq_input["question"], translation_literals)) - answers = [capitalize(fix_ending_punct(str(answer), translation_literals)) for answer in mcq_input["choices"]] + answers = mcq_input["choices"] + gold_idx = mcq_input["gold_idx"] + answers = [capitalize(fix_ending_punct(answer, translation_literals)) for answer in answers] options = build_choices(answers, formulation, translation_literals) options = f"{options}\n" if options else "" - answers = build_answers(answers, formulation, translation_literals) - answer_word = capitalize(translation_literals.answer) + answer_word = capitalize( + translation_literals.answer if not formulation.cot else translation_literals.answer_cot + ) question_word = capitalize(translation_literals.question_word) query = MULTI_CHOICE_QA_QUERY.format( @@ -156,10 +192,25 @@ def prompt_fn(line, task_name: str): options=options, ) + # If we are in few-shot mode, then we want to use cot-answers instead of the actual answers + # NOTE: it's important to do it after query formatting, because otherwise the options will contain cot + is_few_shot = line.get("__few_shots", False) + few_shot_cot = mcq_input.get("few_shot_cot", None) + if is_few_shot: + pass + if few_shot_cot and formulation.cot and is_few_shot: + answers = [capitalize(fix_ending_punct(answer, translation_literals)) for answer in as_list(few_shot_cot)] + gold_idx = list(range(len(answers))) + gold_idx = as_list(gold_idx) + + answers = build_answers( + answers, formulation, translation_literals, is_few_shot=is_few_shot and bool(few_shot_cot) + ) + return Doc( task_name=task_name, query=query, - gold_index=as_list(mcq_input["gold_idx"]), + gold_index=gold_idx, choices=answers, instruction=instruction_val, unconditioned_query=f"{answer_word}{translation_literals.colon}", diff --git a/src/lighteval/tasks/templates/nli.py b/src/lighteval/tasks/templates/nli.py index e8809e17b..730f4978a 100644 --- a/src/lighteval/tasks/templates/nli.py +++ b/src/lighteval/tasks/templates/nli.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import logging from typing import Callable, Literal from typing_extensions import NotRequired, TypedDict @@ -33,6 +34,9 @@ from lighteval.utils.language import Language +logger = logging.getLogger(__name__) + + NLI_TEMPLATE_QUERY_CF = "{instruction}{premise}{word_space}{confirmation_word}{question_mark}" NLI_TEMPLATE_CONT_CF = "{sentence_space}{label}{comma}{word_space}{hypothesis}" @@ -51,6 +55,7 @@ class NLIInput(TypedDict): hypothesis: str gold_idx: int instruction: NotRequired[str] + few_shot_cot: NotRequired[str] class NLIAdapter(TypedDict): @@ -67,6 +72,7 @@ class NLIAdapter(TypedDict): hypothesis: str gold_idx: str instruction: NotRequired[str] + few_shot_cot: NotRequired[str] RelationType = Literal["entailment", "neutral", "contradiction"] @@ -201,6 +207,7 @@ def get_nli_prompt_function( Callable: A function that generates NLI prompts based on the given parameters. 
""" # We use natural implementation for CF formulation to comply with standard evaluation formats + WARNED_ABOUT_COT_INSTRUCTION = False if isinstance(formulation, CFFormulation): return _nli_prompt_function_natural(language, adapter, relations) @@ -219,8 +226,15 @@ def get_relation_label(label: RelationType, translation_literals: TranslationLit # For hybrid we use inlined choices so we use the cf formulation in multichoice prompt fn mcq_prompt_fn = get_mcq_prompt_function( language, - {"context": "premise", "question": "hypothesis", "choices": "choices", "gold_idx": "gold_idx"}, - CFFormulation() if isinstance(formulation, HybridFormulation) else formulation, + { + "context": "premise", + "question": "hypothesis", + "choices": "choices", + "gold_idx": "gold_idx", + "few_shot_cot": "few_shot_cot", + "instruction": "instruction", + }, + CFFormulation(cot=formulation.cot) if isinstance(formulation, HybridFormulation) else formulation, ) def prompt_fn(line: dict, task_name: str): @@ -234,6 +248,7 @@ def prompt_fn(line: dict, task_name: str): premise, hypothesis, gold_idx = input_data["premise"], input_data["hypothesis"], input_data["gold_idx"] premise = fix_ending_punct(capitalize(input_data["premise"]), translation_literals) hypothesis = input_data["hypothesis"] + instruction = input_data.get("instruction", "") if isinstance(formulation, HybridFormulation): # If we have the neither option move it to the end to be consistent with standard NLI evaluation rearranged_labels = labels @@ -243,15 +258,32 @@ def prompt_fn(line: dict, task_name: str): choices_str = f"{translation_literals.comma}{translation_literals.word_space}".join(rearranged_labels[:-1]) hypothesis = f"{hypothesis.rstrip(PUNCT)}{translation_literals.sentence_space}{choices_str}{translation_literals.word_space}{translation_literals.or_word}{translation_literals.word_space}{rearranged_labels[-1]}{translation_literals.question_mark}" + elif isinstance(formulation, MCFFormulation): + if formulation.cot and not instruction: + match formulation: + case MCFFormulation(choice_prefix="Letters") | MCFFormulation(choice_prefix="NativeLetters"): + instruction = f"{translation_literals.nli_mcf_instruction}\n{translation_literals.default_formatting_instruction}" + nonlocal WARNED_ABOUT_COT_INSTRUCTION + if not WARNED_ABOUT_COT_INSTRUCTION: + logger.warning( + f"You are using a COT with MCF formulation but did not provide an instruction. Defaulting to {instruction}" + ) + WARNED_ABOUT_COT_INSTRUCTION = True + case _: + raise ValueError( + "You are using a COT with a unsupported formulation. Either use MCF formulation or provide an instruction." + ) # (hynky1999): Ideally we would not compute logprobs of the Yes/No/Also in CF formulation. However as of right now lighteval doesn't allow to # use multi-context. 
row = { - "instruction": input_data.get("instruction", ""), + **{x: line[x] for x in line if x.startswith("__")}, + "instruction": instruction, "premise": premise, "hypothesis": hypothesis, "gold_idx": gold_idx, "choices": labels, + "few_shot_cot": input_data.get("few_shot_cot", ""), } return mcq_prompt_fn(row, task_name) diff --git a/src/lighteval/tasks/templates/qa.py b/src/lighteval/tasks/templates/qa.py index e798f820d..0a1509e4b 100644 --- a/src/lighteval/tasks/templates/qa.py +++ b/src/lighteval/tasks/templates/qa.py @@ -34,6 +34,7 @@ class QAInput(TypedDict): choices: list[str] context: NotRequired[str] instruction: NotRequired[str] + few_shot_cot: NotRequired[str] class QAAdapter(TypedDict): @@ -41,9 +42,12 @@ class QAAdapter(TypedDict): context: str context: NotRequired[str] instruction: NotRequired[str] + few_shot_cot: NotRequired[str] -def get_qa_prompt_function(language: Language, adapter: Callable[[dict], QAInput | None] | QAAdapter): +def get_qa_prompt_function( + language: Language, adapter: Callable[[dict], QAInput | None] | QAAdapter, cot: bool = False +): """ Create a templated prompt function for a QA task. Example tasks: @@ -70,13 +74,14 @@ def adapter_for_mcq(line: dict) -> MCQInput | None: if input_data is None: return None - choices = list(set(input_data["choices"])) + choices = input_data["choices"] return { **input_data, "gold_idx": list(range(len(choices))), - "choices": choices, } - multichoice_prompt_fn = get_mcq_prompt_function(language, adapter=adapter_for_mcq, formulation=CFFormulation()) + multichoice_prompt_fn = get_mcq_prompt_function( + language, adapter=adapter_for_mcq, formulation=CFFormulation(cot=cot) + ) return multichoice_prompt_fn diff --git a/src/lighteval/tasks/templates/translation.py b/src/lighteval/tasks/templates/translation.py index c90b99e01..ee9e42058 100644 --- a/src/lighteval/tasks/templates/translation.py +++ b/src/lighteval/tasks/templates/translation.py @@ -20,24 +20,29 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import logging from typing import Callable +from langcodes import Language as LangCodeLanguage from langcodes import standardize_tag from typing_extensions import NotRequired, TypedDict from lighteval.tasks.templates.continuation import get_continuation_prompt_function from lighteval.tasks.templates.multichoice import create_adapter_from_dict from lighteval.tasks.templates.utils.formatting_utils import capitalize, fix_ending_punct -from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation +from lighteval.tasks.templates.utils.formulation import CFFormulation, Formulation, MCFFormulation from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS from lighteval.utils.language import Language from lighteval.utils.utils import as_list +logger = logging.getLogger(__name__) + # Template chosen so that it's not very language-dependent, as it's not clear whether one should use the target or source language. # It's also the best template based on https://arxiv.org/pdf/2301.07069. +TRANSLATION_INSTRUCTION = "Translate the following text from {source_language} to {target_language}." 
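As an illustrative aside (a French-to-English pair is assumed here), once langcodes resolves the display names, the cot instruction above expands as follows:

from langcodes import Language as LangCodeLanguage

source = LangCodeLanguage.get("fr").display_name()  # "French"
target = LangCodeLanguage.get("en").display_name()  # "English"
print(TRANSLATION_INSTRUCTION.format(source_language=source, target_language=target))
# Translate the following text from French to English.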
TRANSLATION_CONTEXT = "{source_label}{colon}{sentence_space}{source_text}{sentence_space}{target_label}{colon}" @@ -55,6 +60,7 @@ class TranslationInput(TypedDict): target_text: str | list[str] gold_idx: NotRequired[int | list[int]] instruction: NotRequired[str] + few_shot_cot: NotRequired[str] class TranslationAdapter(TypedDict): @@ -68,8 +74,9 @@ class TranslationAdapter(TypedDict): source_text: str target_text: str - gold_idx: NotRequired[int | list[int]] + gold_idx: NotRequired[str] instruction: NotRequired[str] + few_shot_cot: NotRequired[str] def get_translation_prompt_function( @@ -110,7 +117,13 @@ def get_translation_prompt_function( adapter_fn = create_adapter_from_dict(adapter) continuation_prompt_fn = get_continuation_prompt_function( Language.ENGLISH, - {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, + { + "context": "context", + "continuations": "continuations", + "gold_idx": "gold_idx", + "instruction": "instruction", + "few_shot_cot": "few_shot_cot", + }, formulation, fix_formatting=False, ) @@ -119,6 +132,10 @@ def get_translation_prompt_function( source_label_string = standardize_tag(source_language.value).upper() target_label_string = standardize_tag(target_language.value).upper() + source_language_display_name = LangCodeLanguage.get(source_language.value).display_name() + target_language_display_name = LangCodeLanguage.get(target_language.value).display_name() + + WARNED_ABOUT_COT_INSTRUCTION = False def translation_prompt( line: dict, @@ -143,12 +160,41 @@ def translation_prompt( for text in as_list(input_data["target_text"]) ] + # Handle instruction + instruction_val = input_data.get("instruction") + if formulation.cot and not instruction_val: + match formulation: + case CFFormulation(): + translation_instruction = TRANSLATION_INSTRUCTION.format( + source_language=source_language_display_name, target_language=target_language_display_name + ) + instruction_val = ( + f"{translation_instruction}\n{source_translation_literals.default_formatting_instruction}" + ) + case MCFFormulation(choice_prefix="NativeLetters") | MCFFormulation(choice_prefix="Letters"): + instruction_val = f"{source_translation_literals.multichoice_mcf_instruction}\n{source_translation_literals.default_formatting_instruction}" + case _: + raise ValueError( + "You are using a COT with a unsupported formulation. Either use CF/MCF formulation or provide an instruction." + ) + + nonlocal WARNED_ABOUT_COT_INSTRUCTION + if not WARNED_ABOUT_COT_INSTRUCTION: + logger.warning( + f" You are using a COT with MCF formulation but did not provide an instruction. Defaulting to {instruction_val}" + ) + WARNED_ABOUT_COT_INSTRUCTION = True + + instruction = f"{instruction_val}" if instruction_val else "" + return continuation_prompt_fn( { - "instruction": input_data.get("instruction", ""), + **{x: line[x] for x in line if x.startswith("__")}, + "instruction": instruction, "context": context, "continuations": continuations, "gold_idx": input_data.get("gold_idx", list(range(len(continuations)))), + "few_shot_cot": input_data.get("few_shot_cot", ""), }, task_name, ) diff --git a/src/lighteval/tasks/templates/utils/formulation.py b/src/lighteval/tasks/templates/utils/formulation.py index 6447e01dc..4cffb54ef 100644 --- a/src/lighteval/tasks/templates/utils/formulation.py +++ b/src/lighteval/tasks/templates/utils/formulation.py @@ -39,10 +39,12 @@ class MCFFormulation: Args: choice_prefix (ChoicePrefix, optional): The choice prefix to use for the task. Defaults to "Letters". 
+ cot (bool, optional): Whether to use COT for the task. Defaults to False. """ choice_prefix: ChoicePrefix = "Letters" name: str = "MCF" + cot: bool = False @dataclass @@ -58,6 +60,7 @@ class HybridFormulation: choice_prefix: ChoicePrefix = "Letters" name: str = "Hybrid" + cot: bool = False @dataclass @@ -68,6 +71,7 @@ class CFFormulation: """ name: str = "CF" + cot: bool = False Formulation = CFFormulation | HybridFormulation | MCFFormulation @@ -131,6 +135,7 @@ def build_answers( formulation: Formulation, translation_literals: TranslationLiterals, use_sentence_space: bool = True, + is_few_shot: bool = False, ) -> list[str]: """ Builds a string version of the answers based on passed formulation. @@ -144,7 +149,7 @@ def build_answers( use_sentence_space (bool, optional): Whether to use sentence or word space in front of the answer. Defaults to True. The same value should be passed to `build_choices` function to ensure consistent tokenization. """ - if isinstance(formulation, MCFFormulation): + if isinstance(formulation, MCFFormulation) and not (formulation.cot and is_few_shot): prefixes = get_prefix(formulation.choice_prefix, translation_literals) answers = [prefixes[i] for i in range(len(answers))] diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py index 740475cb7..fcbe09538 100644 --- a/src/lighteval/tasks/templates/utils/translation_literals.py +++ b/src/lighteval/tasks/templates/utils/translation_literals.py @@ -35,6 +35,8 @@ class TranslationLiterals: question_word: str = None # type: ignore answer: str = None # type: ignore + answer_cot: str = None # type: ignore + options_word: str = "options" # type: ignore confirmation_word: str = None # type: ignore yes: str = None # type: ignore no: str = None # type: ignore @@ -50,7 +52,7 @@ class TranslationLiterals: neither: str = None # type: ignore # Punctuation - full_stop: str = "." + full_stop: str = "." # Separating sequence comma: str = "," question_mark: str = "?" exclamation_mark: str = "!" @@ -62,6 +64,17 @@ class TranslationLiterals: # Indices indices: list[str] = field(default_factory=lambda: LETTER_INDICES) + # Instructions + continuation_mcf_instruction: str = "Choose the letter of the most likely continuation." + nli_mcf_instruction: str = "Choose the letter of the most likely relation between the premise and hypothesis." + multichoice_mcf_instruction: str = "Choose the letter of the correct answer." + qa_instruction: str = "Answer the following question." + # Formatting instruction + # This format seems to work quite well across models. We don't put the answer in tags because sometimes models will repeat + # the answer resulting in answer ACTUAL ANSWER + default_formatting_instruction: str = "Output the final answer in format: ." + math_formatting_instruction: str = "Output the answer in \\boxed{}." 
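To make the effect of the new cot flag on build_answers concrete, here is a small sketch assuming English literals: with a plain MCF formulation the gold choices collapse to their letter prefixes, whereas a few-shot doc with cot=True keeps the full chain-of-thought answer text.

from lighteval.tasks.templates.utils.formulation import MCFFormulation, build_answers
from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS
from lighteval.utils.language import Language

literals = TRANSLATION_LITERALS[Language.ENGLISH]
cot_answer = ["I think it's A. jumps over the lazy dog"]

# Standard MCF: the answer is replaced by its letter prefix.
build_answers(cot_answer, MCFFormulation(), literals)
# With cot=True and is_few_shot=True the MCF prefix branch is skipped,
# so the worked answer is shown verbatim in few-shot examples.
build_answers(cot_answer, MCFFormulation(cot=True), literals, is_few_shot=True)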
+ def __getattribute__(self, name: str) -> str: value = super().__getattribute__(name) if value is None: @@ -84,6 +97,8 @@ def __getattribute__(self, name: str) -> str: language=Language.ARABIC, question_word="سؤال", answer="إجابة", + answer_cot="الإجابة خطوة بخطوة", + options_word="خيارات", confirmation_word="صحيح", yes="نعم", no="لا", @@ -103,6 +118,13 @@ def __getattribute__(self, name: str) -> str: sentence_space=" ", colon=":", indices=["أ", "ب", "ج", "د", "هـ", "و", "ز", "ح"], + # Translated using gpt4-o + continuation_mcf_instruction="اختر الحرف الذي يمثل الاستمرار الأكثر احتمالاً", + nli_mcf_instruction="اختر حرف العلاقة الأكثر احتمالا بين المقدمة والفرضية", + qa_instruction="أجب عن السؤال التالي", + multichoice_mcf_instruction="اختر الحرف الذي يمثل الإجابة الصحيحة", + default_formatting_instruction="إخراج الإجابة النهائية بالتنسيق: .", + math_formatting_instruction="اكتب الإجابة في \\boxed{}", ), Language.ARMENIAN: TranslationLiterals(language=Language.ARMENIAN), Language.ASSAMESE: TranslationLiterals(language=Language.ASSAMESE), @@ -224,6 +246,8 @@ def __getattribute__(self, name: str) -> str: language=Language.CHINESE, question_word="问题", answer="答案", + answer_cot="逐步解答", + options_word="选项", confirmation_word="对吗", yes="是的", no="不是", @@ -243,6 +267,13 @@ def __getattribute__(self, name: str) -> str: sentence_space="", colon=":", indices=["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"], + # Translated using gpt4-o + continuation_mcf_instruction="选择最可能的继续的字母。", + nli_mcf_instruction="选择前提和假设之间最可能的关系的字母。", + qa_instruction="回答以下问题。", + multichoice_mcf_instruction="选择正确答案的字母。", + default_formatting_instruction="以格式输出最终答案。", + math_formatting_instruction="在 \\boxed{} 环境中输出答案。", ), Language.CHOKWE: TranslationLiterals(language=Language.CHOKWE), Language.CRIMEAN_TATAR: TranslationLiterals(language=Language.CRIMEAN_TATAR), @@ -327,6 +358,7 @@ def __getattribute__(self, name: str) -> str: language=Language.ENGLISH, question_word="question", answer="answer", + answer_cot="Step-by-Step Answer", confirmation_word="right", yes="yes", no="no", @@ -384,7 +416,9 @@ def __getattribute__(self, name: str) -> str: language=Language.FRENCH, question_word="question", answer="réponse", + answer_cot="Réponse étape par étape", confirmation_word="n'est-ce pas", + options_word="options", yes="oui", no="non", also="de plus", @@ -402,6 +436,13 @@ def __getattribute__(self, name: str) -> str: word_space=" ", sentence_space=" ", colon=":", + continuation_mcf_instruction="Choisissez la lettre de la continuation la plus probable.", + nli_mcf_instruction="Choisissez la lettre de la relation la plus probable entre la prémisse et l’hypothèse.", + qa_instruction="Répondez à la question suivante.", + multichoice_mcf_instruction="Choisissez la lettre de la réponse correcte.", + # Formatting instruction + default_formatting_instruction="Affichez la réponse finale au format: .", + math_formatting_instruction="Donnez la réponse dans \\boxed{}.", ), Language.FRIULIAN: TranslationLiterals(language=Language.FRIULIAN), Language.GALICIAN: TranslationLiterals( @@ -491,6 +532,8 @@ def __getattribute__(self, name: str) -> str: language=Language.HINDI, question_word="सवाल", answer="उत्तर", + answer_cot="चरण दर चरण उत्तर", + options_word="विकल्प", confirmation_word="है ना", yes="हाँ", no="नहीं", @@ -510,6 +553,13 @@ def __getattribute__(self, name: str) -> str: sentence_space=" ", colon=":", indices=["क", "ख", "ग", "घ", "ङ", "च"], + continuation_mcf_instruction="अत्यधिक संभावित निरंतरता का अक्षर चुनें।", + 
nli_mcf_instruction="आधार और परिकल्पना के बीच सबसे संभावित संबंध का अक्षर चुनें।", + qa_instruction="निम्नलिखित प्रश्न का उत्तर दें।", + multichoice_mcf_instruction="सही उत्तर का अक्षर चुनें।", + # Formatting instruction + default_formatting_instruction="अंतिम उत्तर को इस प्रारूप में आउटपुट करें: ।", + math_formatting_instruction="उत्तर को \\boxed{} में आउटपुट करें।", ), Language.HUNGARIAN: TranslationLiterals( language=Language.HUNGARIAN, @@ -760,6 +810,8 @@ def __getattribute__(self, name: str) -> str: language=Language.RUSSIAN, question_word="вопрос", answer="ответ", + answer_cot="Пошаговое решение", + options_word="варианты", confirmation_word="верно", yes="да", no="нет", @@ -778,6 +830,13 @@ def __getattribute__(self, name: str) -> str: sentence_space=" ", colon=":", indices=["А", "Б", "В", "Г", "Д", "Е"], + continuation_mcf_instruction="Выберите букву наиболее вероятного продолжения.", + nli_mcf_instruction="Выберите букву наиболее вероятной связи между предпосылкой и гипотезой.", + qa_instruction="Ответьте на следующий вопрос.", + multichoice_mcf_instruction="Выберите букву правильного ответа.", + # Formatting instruction + default_formatting_instruction="Выведите окончательный ответ в формате: .", + math_formatting_instruction="Выведите ответ в \\boxed{}.", ), Language.SAMOAN: TranslationLiterals(language=Language.SAMOAN), Language.SANGO: TranslationLiterals(language=Language.SANGO), @@ -914,7 +973,9 @@ def __getattribute__(self, name: str) -> str: language=Language.SWAHILI, question_word="swali", answer="jibu", + answer_cot="jibu la Hatua kwa Hatua", confirmation_word="sahihi", + options_word="chaguo", yes="ndiyo", no="hapana", also="pia", @@ -931,6 +992,14 @@ def __getattribute__(self, name: str) -> str: word_space=" ", sentence_space=" ", colon=":", + # Translated using gpt-4o + continuation_mcf_instruction="Chagua herufi ya mwendelezo unaowezekana zaidi.", + nli_mcf_instruction="Chagua herufi ya uhusiano unaowezekana kati ya dhana na dhana.", + qa_instruction="Jibu swali lifuatalo.", + multichoice_mcf_instruction="Chagua herufi ya jibu sahihi.", + # Formatting instruction + default_formatting_instruction="Toa jibu la mwisho katika umbizo: .", + math_formatting_instruction="Toa jibu katika \\boxed{}.", ), Language.SWATI: TranslationLiterals(language=Language.SWATI), Language.SWEDISH: TranslationLiterals( @@ -993,6 +1062,8 @@ def __getattribute__(self, name: str) -> str: language=Language.TELUGU, question_word="ప్రశ్న", answer="జవాబు", + answer_cot="దశలవారీగా సమాధానం", + options_word="ఎంపికలు", confirmation_word="కదా", yes="అవును", no="కాదు", @@ -1010,12 +1081,21 @@ def __getattribute__(self, name: str) -> str: word_space=" ", sentence_space=" ", colon=":", - indices=["ఎ", "బి", "సి", "డి", "ఇ"], + indices=["అ", "ఆ", "ఇ", "ఈ", "ఉ", "ఊ"], + continuation_mcf_instruction="అత్యంత సాధ్యమైన కొనసాగింపును సూచించే అక్షరాన్ని ఎంచుకోండి.", + nli_mcf_instruction="పూర్వాపరం మరియు పరికల్పన మధ్య అత్యంత సంభావ్య సంబంధం యొక్క అక్షరాన్ని ఎంచుకోండి.", + qa_instruction="క్రింది ప్రశ్నకు సమాధానం ఇవ్వండి.", + multichoice_mcf_instruction="సరైన సమాధానాన్ని సూచించే అక్షరాన్ని ఎంచుకోండి.", + # Formatting instruction + default_formatting_instruction="తుది సమాధానాన్ని ఈ ఫార్మాట్‌లో అవుట్‌పుట్ చేయండి: .", + math_formatting_instruction="సమాధానాన్ని \\boxed{} లో ఇవ్వండి.", ), Language.THAI: TranslationLiterals( language=Language.THAI, question_word="คำถาม", answer="คำตอบ", + answer_cot="คำตอบทีละขั้นตอน", + options_word="ตัวเลือก", confirmation_word="ใช่ไหม", yes="ใช่", no="ไม่", @@ -1027,13 +1107,19 @@ def 
__getattribute__(self, name: str) -> str: neither="ไม่ใช่ทั้งสองอย่าง", or_word="หรือ", full_stop=".", - comma=",", question_mark="?", exclamation_mark="!", word_space="", sentence_space=" ", colon=":", - indices=["๑", "๒", "๓", "๔", "๕", "๖", "๗", "๘", "๙", "๐"], + indices=["ก", "ข", "ค", "ง", "จ", "ฉ", "ช", "ซ"], + continuation_mcf_instruction="เลือกตัวอักษรของการดำเนินการต่อที่มีความเป็นไปได้มากที่สุด", + nli_mcf_instruction="เลือกตัวอักษรที่มีความสัมพันธ์ที่เป็นไปได้มากที่สุดระหว่างข้อตั้งและสมมติฐาน", + qa_instruction="ตอบคำถามต่อไปนี้", + multichoice_mcf_instruction="เลือกตัวอักษรของคำตอบที่ถูกต้อง", + # Formatting instruction + default_formatting_instruction="ส่งออกคำตอบสุดท้ายในรูปแบบ: ", + math_formatting_instruction="แสดงคำตอบใน \\boxed{}", ), Language.TIGRINYA: TranslationLiterals(language=Language.TIGRINYA), Language.TOK_PISIN: TranslationLiterals(language=Language.TOK_PISIN), @@ -1046,7 +1132,9 @@ def __getattribute__(self, name: str) -> str: language=Language.TURKISH, question_word="soru", answer="cevap", + answer_cot="adım adım cevap", confirmation_word="değil mi", + options_word="seçenekler", yes="evet", no="hayır", also="ayrıca", @@ -1064,6 +1152,13 @@ def __getattribute__(self, name: str) -> str: word_space=" ", sentence_space=" ", colon=":", + continuation_mcf_instruction="En olası devamı temsil eden harfi seçin.", + nli_mcf_instruction="Öncül ile hipotez arasındaki olası ilişkinin harfini seçiniz.", + qa_instruction="Aşağıdaki soruyu yanıtlayın.", + multichoice_mcf_instruction="Doğru cevabı temsil eden harfi seçin.", + # Formatting instruction + default_formatting_instruction="Son cevabı şu formatta çıktı olarak verin: .", + math_formatting_instruction="Cevabı \\boxed{} içinde verin.", ), Language.TURKMEN: TranslationLiterals(language=Language.TURKMEN), Language.TWI: TranslationLiterals(language=Language.TWI), diff --git a/tests/metrics/test_extractive_match.py b/tests/metrics/test_extractive_match.py index c3a12c813..2b6b5efc6 100644 --- a/tests/metrics/test_extractive_match.py +++ b/tests/metrics/test_extractive_match.py @@ -44,7 +44,7 @@ def compare_strings( gold: str, pred: str, language: Language = Language.ENGLISH, - match_types: list[str] = ["latex", "expr"], + match_types: list[str] = ["latex", "expr", "NativeLetters"], precision: int = 6, ): """Helper function to compare strings using the math extraction metrics""" @@ -56,7 +56,9 @@ def compare_strings( elif match_type == "expr": extraction_targets.append(ExprExtractionConfig()) elif match_type == "NativeLetters": - extraction_targets.append(IndicesExtractionConfig(prefix_for_extraction="NativeLetters")) + extraction_targets.append( + IndicesExtractionConfig(prefix_for_extraction="NativeLetters", bb_match_priority=0) + ) extraction_targets = tuple(extraction_targets) # Convert to tuple @@ -104,6 +106,7 @@ def test_extraction_abc(gold, pred, expected): ("B", "B。 不是 A", Language.CHINESE, 1), ("B", "B。不是 A", Language.CHINESE, 1), ("B", "B不是 A", Language.CHINESE, 1), + ("B", "Hmm I am not sure it's A ?? Not it's B. 
Surely not D", Language.CHINESE, 1), ], ) def test_multilingual_extraction_abc(gold, pred, language, expected): diff --git a/tests/metrics/test_multilingual_metrics.py b/tests/metrics/test_multilingual_metrics.py new file mode 100644 index 000000000..c7eabe8a4 --- /dev/null +++ b/tests/metrics/test_multilingual_metrics.py @@ -0,0 +1,145 @@ +from lighteval.metrics.dynamic_metrics import ( + multilingual_quasi_exact_match_metric, + multilingual_quasi_f1_score_metric, +) +from lighteval.utils.language import Language + + +def test_multilingual_quasi_exact_match_happy_path(): + """Test basic functionality of exact match metric""" + metric = multilingual_quasi_exact_match_metric(language=Language.ENGLISH) + + # Test exact match + result = metric.sample_level_fn( + golds=["hello world"], + predictions=["hello world"], + ) + assert result == 1 + + # Test with different spacing/punctuation + result = metric.sample_level_fn( + golds=["hello world"], + predictions=["hello, world!"], + ) + assert result == 1 + + # Test with no match + result = metric.sample_level_fn( + golds=["hello world"], + predictions=["goodbye world"], + ) + assert result == 0 + + +def test_multilingual_quasi_exact_match_bb_extraction(): + """Test bold text extraction functionality""" + metric = multilingual_quasi_exact_match_metric(language=Language.ENGLISH, extract_bb=True) + + # Test with single bold tag + result = metric.sample_level_fn( + golds=["answer"], + predictions=["The correct answer is answer"], + ) + assert result == 1 + + # Test with multiple bold tags - should take last one + result = metric.sample_level_fn( + golds=["final answer"], + predictions=["First wrong then final answer"], + ) + assert result == 1 + + # Test with no bold tags - should use full text + result = metric.sample_level_fn( + golds=["answer"], + predictions=["answer"], + ) + assert result == 1 + + # Test with empty bold tags + result = metric.sample_level_fn( + golds=["answer"], + predictions=[" answer"], + ) + assert result == 0 + + +def test_multilingual_quasi_f1_score_happy_path(): + """Test basic functionality of F1 score metric""" + metric = multilingual_quasi_f1_score_metric(language=Language.ENGLISH) + + # Test perfect match + result = metric.sample_level_fn( + golds=["hello world"], + predictions=["hello world"], + ) + assert result == 1 + + # Test partial match + result = metric.sample_level_fn( + golds=["hello beautiful world"], + predictions=["hello world"], + ) + assert result > 0 and result < 1 + + # Test no match + result = metric.sample_level_fn( + golds=["hello world"], + predictions=["goodbye moon"], + ) + assert result == 0 + + +def test_multilingual_quasi_f1_score_bb_extraction(): + """Test bold text extraction functionality with F1 score""" + metric = multilingual_quasi_f1_score_metric(language=Language.ENGLISH, extract_bb=True) + + # Test with single bold tag + result = metric.sample_level_fn( + golds=["answer key"], + predictions=["The correct answer is answer key"], + ) + assert result == 1 + + # Test with multiple bold tags - should take last one + result = metric.sample_level_fn( + golds=["final answer"], + predictions=["First wrong then final answer"], + ) + assert result == 1 + + # Test with partial match in bold + result = metric.sample_level_fn( + golds=["complete answer key"], + predictions=["The text contains answer key"], + ) + assert result > 0 and result < 1 + + # Test with no bold tags - should use full text + result = metric.sample_level_fn( + golds=["answer"], + predictions=["answer"], + ) + assert result == 
1 + + +def test_multilingual_support(): + """Test metrics work with different languages""" + languages = [Language.ENGLISH, Language.FRENCH, Language.CHINESE] + + for lang in languages: + # Test exact match + em_metric = multilingual_quasi_exact_match_metric(language=lang) + result = em_metric.sample_level_fn( + golds=["test"], + predictions=["test"], + ) + assert result == 1 + + # Test F1 score + f1_metric = multilingual_quasi_f1_score_metric(language=lang) + result = f1_metric.sample_level_fn( + golds=["test"], + predictions=["test"], + ) + assert result == 1 diff --git a/tests/tasks/templates/test_continuation.py b/tests/tasks/templates/test_continuation.py index 9681ecd66..f09a3e1d3 100644 --- a/tests/tasks/templates/test_continuation.py +++ b/tests/tasks/templates/test_continuation.py @@ -46,6 +46,8 @@ def test_continuation_prompt_mcf(): doc.query == """\ The quick brown fox + +Options: A. jumps over the lazy dog B. runs through the forest C. chases a rabbit @@ -131,7 +133,10 @@ def test_continuation_optional_keys(): doc.query == """\ Choose the most likely continuation: + In the morning, I like to + +Options: A. drink coffee B. go for a run C. read the news @@ -142,3 +147,88 @@ def test_continuation_optional_keys(): assert doc.unconditioned_query == "Answer:" assert doc.choices == [" A", " B", " C"] assert doc.gold_index == [0] + + +def test_continuation_prompt_mcf_cot(): + """Test multiple-choice format continuation prompt generation with cot.""" + test_input = { + "context": "The quick brown fox", + "continuations": ["jumps over the lazy dog", "Runs through the forest", "Chases a rabbit"], + "gold_idx": 0, + "__few_shots": True, + "few_shot_cot": "i think it's A. jumps over the lazy dog", + "instruction": "Choose the letter of the most likely continuation.", + } + + prompt_fn = get_continuation_prompt_function( + Language.ENGLISH, + { + "context": "context", + "continuations": "continuations", + "gold_idx": "gold_idx", + "few_shot_cot": "few_shot_cot", + "instruction": "instruction", + }, + MCFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_continuation_task") + + # We expect the contuation to be decapitalized as it's continuation of non-ended sentence + assert ( + doc.query + == """\ +Choose the letter of the most likely continuation. + +The quick brown fox + +Options: + A. jumps over the lazy dog + B. runs through the forest + C. chases a rabbit +Step-by-Step Answer:\ +""" + ) + + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" I think it's A. jumps over the lazy dog"] + assert doc.gold_index == [0] + + +def test_continuation_default_instruction_mcf(): + """Test default instruction for MCF continuation prompt.""" + test_input = { + "context": "The quick brown fox", + "continuations": ["jumps over the lazy dog", "Runs through the forest", "Chases a rabbit"], + "gold_idx": 0, + } + + # Note: "instruction" key is NOT in the key_map + prompt_fn = get_continuation_prompt_function( + Language.ENGLISH, + {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, + MCFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_continuation_task_default_mcf") + + expected_instruction = ( + "Choose the letter of the most likely continuation.\nOutput the final answer in format: ." + ) + assert ( + doc.query + == f"""\ +{expected_instruction} + +The quick brown fox + +Options: + A. jumps over the lazy dog + B. runs through the forest + C. 
chases a rabbit +Step-by-Step Answer:\ +""" + ) + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" A", " B", " C"] + assert doc.gold_index == [0] diff --git a/tests/tasks/templates/test_copa.py b/tests/tasks/templates/test_copa.py index 775fb37fb..9f7f40ffe 100644 --- a/tests/tasks/templates/test_copa.py +++ b/tests/tasks/templates/test_copa.py @@ -23,7 +23,7 @@ import pytest from lighteval.tasks.templates.copa import get_copa_prompt_function -from lighteval.tasks.templates.utils.formulation import CFFormulation +from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation from lighteval.utils.language import Language @@ -60,3 +60,55 @@ def test_copa_prompt_cf(cause_effect): assert doc.unconditioned_query == "" assert doc.choices == [" he has big muscles", " he is weak"] assert doc.gold_index == [0] + + +@pytest.mark.parametrize("cause_effect", ["cause", "effect"]) +def test_copa_prompt_mcf_cot(cause_effect): + """ + Tests that copa prompt function works correctly for both cause/effect. + Since it's pretty much a wrapper around continuation template we just test single formulation. + + """ + test_input = { + "cause_effect": cause_effect, + "context": "He is strong", + "continuations": ["he has big muscles", "he is weak"], + "gold_idx": 0, + "__few_shots": True, + "few_shot_cot": "i think it's A. he has big muscles", + "instruction": "Choose the letter of the most likely continuation.", + } + + prompt_fn = get_copa_prompt_function( + Language.ENGLISH, + { + "cause_effect": "cause_effect", + "context": "context", + "continuations": "continuations", + "gold_idx": "gold_idx", + "few_shot_cot": "few_shot_cot", + "instruction": "instruction", + }, + MCFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_task") + + cause_effect_word = "because" if cause_effect == "cause" else "therefore" + assert ( + doc.query + == f"""\ +Choose the letter of the most likely continuation. + +He is strong {cause_effect_word} + +Options: + A. he has big muscles + B. he is weak +Step-by-Step Answer:\ +""" + ) + + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" I think it's A. he has big muscles"] + assert doc.gold_index == [0] diff --git a/tests/tasks/templates/test_hellaswag.py b/tests/tasks/templates/test_hellaswag.py index 2ef7b895b..39b8cb78c 100644 --- a/tests/tasks/templates/test_hellaswag.py +++ b/tests/tasks/templates/test_hellaswag.py @@ -91,6 +91,8 @@ def test_hellaswag_prompt_mcf(): doc.query == """\ Fitness:\nHe is strong he is fast + +Options: A. he has big muscles B. he is weak Answer:\ @@ -158,3 +160,102 @@ def test_hellaswag_single_ctx(): doc = prompt_fn(test_input, "test_task") assert doc.query == "Fitness:\nHe is strong." + + +def test_hellaswag_prompt_mcf_cot(): + """ + Tests that hellaswag prompt function works correctly. + Since it's pretty much a wrapper around continuation template we just test single formulation. + + """ + test_input = { + "activity_label": "fitness", + "ctx_a": "He is strong", + "ctx_b": "He is fast", + "continuations": ["he has big muscles", "he is weak"], + "gold_idx": 0, + "__few_shots": True, + "few_shot_cot": "i think it's A. 
he has big muscles", + "instruction": "Choose the letter of the most likely continuation.", + } + + prompt_fn = get_hellaswag_prompt_function( + Language.ENGLISH, + { + "activity_label": "activity_label", + "continuations": "continuations", + "gold_idx": "gold_idx", + "ctx_a": "ctx_a", + "ctx_b": "ctx_b", + "few_shot_cot": "few_shot_cot", + "instruction": "instruction", + }, + MCFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_task") + assert ( + doc.query + == """\ +Choose the letter of the most likely continuation. + +Fitness: +He is strong he is fast + +Options: + A. he has big muscles + B. he is weak +Step-by-Step Answer:\ +""" + ) + + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" I think it's A. he has big muscles"] + assert doc.gold_index == [0] + + +def test_hellaswag_default_instruction_mcf(): + """Test default instruction for MCF hellaswag prompt.""" + test_input = { + "activity_label": "fitness", + "ctx_a": "He is strong", + "ctx_b": "He is fast", + "continuations": ["he has big muscles", "he is weak"], + "gold_idx": 0, + } + + # Note: "instruction" key is NOT in the key_map + prompt_fn = get_hellaswag_prompt_function( + Language.ENGLISH, + { + "activity_label": "activity_label", + "continuations": "continuations", + "gold_idx": "gold_idx", + "ctx_a": "ctx_a", + "ctx_b": "ctx_b", + }, + MCFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_task_default_mcf") + + expected_instruction = ( + "Choose the letter of the most likely continuation.\nOutput the final answer in format: ." + ) + assert ( + doc.query + == f"""\ +{expected_instruction} + +Fitness: +He is strong he is fast + +Options: + A. he has big muscles + B. he is weak +Step-by-Step Answer:\ +""" + ) + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" A", " B"] + assert doc.gold_index == [0] diff --git a/tests/tasks/templates/test_math_qa.py b/tests/tasks/templates/test_math_qa.py new file mode 100644 index 000000000..ef0a48766 --- /dev/null +++ b/tests/tasks/templates/test_math_qa.py @@ -0,0 +1,66 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ + +import logging + +from lighteval.tasks.templates.math_qa import get_math_qa_prompt_function +from lighteval.tasks.templates.qa import QAInput +from lighteval.utils.language import Language + + +logger = logging.getLogger(__name__) + + +def test_math_qa_prompt_cf_cot_default_instruction(): + """ + Tests Math QA with CoT and default instruction. + """ + test_input = { + "question": "Solve for x: x + 5 = 10", + "choices": ["5"], + } + + prompt_fn = get_math_qa_prompt_function( + language=Language.ENGLISH, + adapter=lambda x: QAInput( + question=x["question"], + choices=x["choices"], + ), + cot=True, + ) + + doc = prompt_fn(test_input, "test_task") + + assert ( + doc.query + == """\ +Answer the following question. +Output the answer in \\boxed{}. + +Question: Solve for x: x + 5 = 10 +Step-by-Step Answer:\ +""" + ) + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" 5"] + assert doc.gold_index == [0] diff --git a/tests/tasks/templates/test_multichoice.py b/tests/tasks/templates/test_multichoice.py index 9b3614ed2..b0d4cd577 100644 --- a/tests/tasks/templates/test_multichoice.py +++ b/tests/tasks/templates/test_multichoice.py @@ -177,6 +177,7 @@ def test_multichoice_optional_keys(): doc.query == """\ Please answer the following question about geography. + France is big. Question: What is the capital of France? A. London @@ -186,3 +187,123 @@ def test_multichoice_optional_keys(): Answer:\ """ ) + + +def test_multichoice_prompt_mcf_cot(): + """Test multiple-choice format (MCF) with COT prompt generation for multichoice questions.""" + test_input = { + "question": "What is the capital of France?", + "choices": ["London", "Paris", "Berlin", "Madrid"], + "gold_idx": 1, + "instruction": "Please answer the following question about geography.", + "__few_shots": True, + "few_shot_cot": "i think it's D", + } + + prompt_fn = get_mcq_prompt_function( + Language.ENGLISH, + { + "question": "question", + "choices": "choices", + "gold_idx": "gold_idx", + "few_shot_cot": "few_shot_cot", + "instruction": "instruction", + }, + MCFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_task") + pass + + assert ( + doc.query + == """\ +Please answer the following question about geography. + +Question: What is the capital of France? + A. London + B. Paris + C. Berlin + D. Madrid +Step-by-Step Answer:\ +""" + ) + + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" I think it's D"] + + +def test_multichoice_default_instruction_mcf(): + """Test default instruction for MCF multichoice prompt.""" + test_input = { + "question": "What is the capital of France?", + "choices": ["London", "Paris", "Berlin", "Madrid"], + "gold_idx": 1, + } + + # Note: "instruction" key is NOT in the key_map + prompt_fn = get_mcq_prompt_function( + Language.ENGLISH, + { + "question": "question", + "choices": "choices", + "gold_idx": "gold_idx", + }, + MCFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_task_default_mcf") + + # Default instruction from TranslationLiterals.ENGLISH.multichoice_instruction + expected_instruction = "Choose the letter of the correct answer.\nOutput the final answer in format: ." + expected_query = f"""\ +{expected_instruction} + +Question: What is the capital of France? + A. London + B. Paris + C. Berlin + D. 
Madrid +Step-by-Step Answer:\ +""" + assert doc.query == expected_query + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" A", " B", " C", " D"] + assert doc.gold_index == [1] + + +def test_multichoice_default_instruction_cf(): + """Test default instruction for CF multichoice prompt.""" + test_input = { + "question": "What is the capital of France?", + "choices": ["London", "Paris", "Berlin", "Madrid"], + "gold_idx": 1, + } + + # Note: "instruction" key is NOT in the key_map + prompt_fn = get_mcq_prompt_function( + Language.ENGLISH, + { + "question": "question", + "choices": "choices", + "gold_idx": "gold_idx", + }, + CFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_task_default_cf") + + # Default instruction from TranslationLiterals.ENGLISH.multichoice_instruction + expected_instruction = "Answer the following question.\nOutput the final answer in format: ." + assert ( + doc.query + == f"""\ +{expected_instruction} + +Question: What is the capital of France? +Step-by-Step Answer:\ +""" + ) + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" London", " Paris", " Berlin", " Madrid"] + assert doc.gold_index == [1] diff --git a/tests/tasks/templates/test_nli.py b/tests/tasks/templates/test_nli.py index a634432b9..193f9cce6 100644 --- a/tests/tasks/templates/test_nli.py +++ b/tests/tasks/templates/test_nli.py @@ -21,7 +21,7 @@ # SOFTWARE. from lighteval.tasks.templates.nli import get_nli_prompt_function -from lighteval.tasks.templates.utils.formulation import CFFormulation, HybridFormulation +from lighteval.tasks.templates.utils.formulation import CFFormulation, HybridFormulation, MCFFormulation from lighteval.utils.language import Language @@ -114,3 +114,85 @@ def test_nli_prompt_hybrid(): assert doc.unconditioned_query == "Answer:" assert doc.choices == [" True", " Neither", " False"] assert doc.gold_index == [2] + + +def test_nli_prompt_mcf_cot(): + """Test multiple-choice format NLI prompt generation with cot.""" + test_input = { + "premise": "The cat is sleeping on the couch.", + "hypothesis": "The cat is awake.", + "gold_idx": 2, + "__few_shots": True, + "few_shot_cot": "i think it's A. True", + "instruction": "Choose the correct relation between the premise and hypothesis", + } + + prompt_fn = get_nli_prompt_function( + Language.ENGLISH, + { + "hypothesis": "hypothesis", + "premise": "premise", + "gold_idx": "gold_idx", + "few_shot_cot": "few_shot_cot", + "instruction": "instruction", + }, + ["entailment", "neutral", "contradiction"], + formulation=MCFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_nli_task") + + assert ( + doc.query + == """\ +Choose the correct relation between the premise and hypothesis + +The cat is sleeping on the couch. +Question: The cat is awake. + A. True + B. Neither + C. False +Step-by-Step Answer:\ +""" + ) + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" I think it's A. True"] + assert doc.gold_index == [0] + + +def test_nli_prompt_mcf_cot_default_instruction(): + """Test multiple-choice format NLI prompt generation with cot.""" + test_input = { + "premise": "The cat is sleeping on the couch.", + "hypothesis": "The cat is awake.", + "gold_idx": 2, + "__few_shots": True, + "few_shot_cot": "i think it's A. 
True", + } + + prompt_fn = get_nli_prompt_function( + Language.ENGLISH, + {"hypothesis": "hypothesis", "premise": "premise", "gold_idx": "gold_idx", "few_shot_cot": "few_shot_cot"}, + ["entailment", "neutral", "contradiction"], + formulation=MCFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_nli_task") + + assert ( + doc.query + == """\ +Choose the letter of the most likely relation between the premise and hypothesis. +Output the final answer in format: . + +The cat is sleeping on the couch. +Question: The cat is awake. + A. True + B. Neither + C. False +Step-by-Step Answer:\ +""" + ) + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" I think it's A. True"] + assert doc.gold_index == [0] diff --git a/tests/tasks/templates/test_qa.py b/tests/tasks/templates/test_qa.py new file mode 100644 index 000000000..7c9a4dc52 --- /dev/null +++ b/tests/tasks/templates/test_qa.py @@ -0,0 +1,135 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from lighteval.tasks.templates.qa import ( + QAInput, # Reusing QAInput for simplicity + get_qa_prompt_function, +) +from lighteval.utils.language import Language + + +def test_qa_prompt_cf(): + """ + Tests that QA prompt function works correctly for CF formulation. + """ + test_input = { + "question": "What is 2 + 2?", + "choices": ["4", "5", "6"], + } + + prompt_fn = get_qa_prompt_function( + language=Language.ENGLISH, + adapter=lambda x: QAInput( + question=x["question"], + choices=x["choices"], + ), + cot=False, + ) + + doc = prompt_fn(test_input, "test_task") + assert doc is not None + + assert ( + doc.query + == """\ +Question: What is 2 + 2? +Answer:\ +""" + ) + assert doc.unconditioned_query == "Answer:" + assert doc.choices == [" 4", " 5", " 6"] + assert doc.gold_index == [0, 1, 2] + + +def test_qa_prompt_cf_cot_default_instruction(): + """ + Tests QA with CoT and default instruction, checking for the warning. + """ + test_input = { + "question": "Solve for x: x + 5 = 10", + "choices": ["5", "10", "0"], + } + + prompt_fn = get_qa_prompt_function( + language=Language.ENGLISH, + adapter=lambda x: QAInput( + question=x["question"], + choices=x["choices"], + ), + cot=True, + ) + + doc = prompt_fn(test_input, "test_task") + + assert ( + doc.query + == """\ +Answer the following question. +Output the final answer in format: . 
+ +Question: Solve for x: x + 5 = 10 +Step-by-Step Answer:\ +""" + ) + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" 5", " 10", " 0"] + assert doc.gold_index == [0, 1, 2] + + +def test_qa_prompt_cf_cot_custom_instruction(): + """ + Tests QA with CoT and custom instruction. + """ + test_input = { + "question": "Solve for x: x + 5 = 10", + "choices": ["5", "10", "0"], + "instruction": "Answer the following question. Output the final answer in format: .", + "few_shot_cot": "1+2+3+4+5=15", + "__few_shots": True, + } + + prompt_fn = get_qa_prompt_function( + language=Language.ENGLISH, + adapter=lambda x: QAInput( + question=x["question"], + choices=x["choices"], + instruction=x["instruction"], + few_shot_cot=x["few_shot_cot"], + ), + cot=True, + ) + + doc = prompt_fn(test_input, "test_task") + + assert ( + doc.query + == """\ +Answer the following question. Output the final answer in format: . + +Question: Solve for x: x + 5 = 10 +Step-by-Step Answer:\ +""" + ) + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" 1+2+3+4+5=15"] + assert doc.gold_index == [0] diff --git a/tests/tasks/templates/test_translation.py b/tests/tasks/templates/test_translation.py index eab59cf18..1089de8ac 100644 --- a/tests/tasks/templates/test_translation.py +++ b/tests/tasks/templates/test_translation.py @@ -21,6 +21,8 @@ # SOFTWARE. +import pytest + from lighteval.tasks.templates.translation import get_translation_prompt_function from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation from lighteval.utils.language import Language @@ -81,6 +83,8 @@ def test_translation_prompt_mcf(): doc.query == """\ CS: Ahoj, jak se máš? FR: + +Options: A. Bonjour, comment allez-vous? B. Ciao, come stai? Answer:\ @@ -118,3 +122,171 @@ def test_translation_prompt_cf_formatting(): assert doc.unconditioned_query == "" assert doc.choices == [" 你好吗?"] assert doc.gold_index == [0] + + +def test_translation_cot_default_instruction(): + """ + Tests that translation prompt function uses default instruction when CoT is set to true. + """ + test_input = { + "source_text": "How are you?", + "target_text": "你好吗?", + } + + prompt_fn = get_translation_prompt_function( + source_language=Language.ENGLISH, + target_language=Language.CHINESE, + adapter=lambda x: { + "source_text": x["source_text"], + "target_text": x["target_text"], + }, + formulation=CFFormulation(cot=True), + ) + + doc = prompt_fn(test_input, "test_task") + assert doc is not None + + # Check that the default instruction is included + expected_instruction = "Translate the following text from English to Chinese.\n" + assert doc.query.startswith(expected_instruction) + assert "EN: How are you? ZH:" in doc.query + assert doc.choices == [" 你好吗?"] + assert doc.gold_index == [0] + + +def test_translation_cot_default_instruction_mcf(): + """ + Tests that translation prompt function uses default instruction when CoT is set to true for MCF formulation. 
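+    Both the default multichoice instruction and the answer-format line should be prepended to the query.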
+    """
+    test_input = {
+        "source_text": "Ahoj, jak se máš?",
+        "target_text": ["Bonjour, comment allez-vous?", "Ciao, come stai?"],
+    }
+
+    prompt_fn = get_translation_prompt_function(
+        source_language=Language.CZECH,
+        target_language=Language.FRENCH,
+        adapter=lambda x: {
+            "source_text": x["source_text"],
+            "target_text": x["target_text"],
+            "gold_idx": 0,
+        },
+        formulation=MCFFormulation(cot=True),
+    )
+
+    doc = prompt_fn(test_input, "test_task")
+    assert doc is not None
+
+    # Check that both default instructions are included
+    expected_instructions = "Choose the letter of the correct answer.\nOutput the final answer in format: .\n\n"
+    assert doc.query.startswith(expected_instructions)
+    assert "CS: Ahoj, jak se máš? FR:" in doc.query
+    assert "A. Bonjour, comment allez-vous?" in doc.query
+    assert "B. Ciao, come stai?" in doc.query
+    assert doc.choices == [" A", " B"]
+    assert doc.gold_index == [0]
+
+
+def test_translation_cot_user_instruction():
+    """
+    Tests that the translation prompt function uses the user-provided instruction when available.
+    """
+    test_input = {
+        "source_text": "How are you?",
+        "target_text": "你好吗?",
+        "instruction": "Please translate this English text to Chinese:",
+    }
+
+    prompt_fn = get_translation_prompt_function(
+        source_language=Language.ENGLISH,
+        target_language=Language.CHINESE,
+        adapter=lambda x: {
+            "source_text": x["source_text"],
+            "target_text": x["target_text"],
+            "instruction": x["instruction"],
+        },
+        formulation=CFFormulation(cot=True),
+    )
+
+    doc = prompt_fn(test_input, "test_task")
+    assert doc is not None
+
+    # Check that the user-provided instruction is used as the prompt prefix
+    expected_instructions = "Please translate this English text to Chinese:\n\n"
+    assert doc.query.startswith(expected_instructions)
+    assert "EN: How are you? ZH:" in doc.query
+    assert doc.choices == [" 你好吗?"]
+    assert doc.gold_index == [0]
+
+
+def test_translation_cot_mcf_number_prefix_error():
+    """
+    Tests that the translation prompt function raises an error when using CoT with MCF and a Numbers choice prefix.
+    """
+    test_input = {
+        "source_text": "How are you?",
+        "target_text": "你好吗?",
+    }
+
+    with pytest.raises(ValueError, match="You are using a COT with a unsupported formulation"):
+        prompt_fn = get_translation_prompt_function(
+            source_language=Language.ENGLISH,
+            target_language=Language.CHINESE,
+            adapter=lambda x: {
+                "source_text": x["source_text"],
+                "target_text": x["target_text"],
+                "gold_idx": 0,
+            },
+            formulation=MCFFormulation(cot=True, choice_prefix="Numbers"),
+        )
+
+        prompt_fn(test_input, "test_task")
+
+
+def test_translation_prompt_mcf_cot():
+    """
+    Tests translation prompt generation with MCF formulation, CoT, and a user-provided instruction.
+    Since the translation template is essentially a wrapper around the continuation template, a single formulation is enough.
+    """
+    test_input = {
+        "source_text": "How are you?",
+        "target_text": ["你好吗?", "你怎么样?"],
+        "__few_shots": True,
+        "few_shot_cot": "i think it's A.",
+        "gold_idx": 0,
+        "instruction": "Choose the letter of the most likely continuation.",
+    }
+
+    prompt_fn = get_translation_prompt_function(
+        Language.ENGLISH,
+        Language.CHINESE,
+        {
+            "source_text": "source_text",
+            "target_text": "target_text",
+            "gold_idx": "gold_idx",
+            "few_shot_cot": "few_shot_cot",
+            "instruction": "instruction",
+        },
+        MCFFormulation(cot=True),
+    )
+
+    doc = prompt_fn(test_input, "test_task")
+
+    assert (
+        doc.query
+        == """\
+Choose the letter of the most likely continuation.
+
+EN: How are you? ZH:
+
+Options:
+ A. 你好吗?
+ B. 你怎么样?
+Step-by-Step Answer:\ +""" + ) + + assert doc.unconditioned_query == "Step-by-Step Answer:" + assert doc.choices == [" I think it's A."] + assert doc.gold_index == [0]