diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py
index 3e0b45121..0112ef4f1 100644
--- a/src/lighteval/metrics/dynamic_metrics.py
+++ b/src/lighteval/metrics/dynamic_metrics.py
@@ -25,7 +25,9 @@
import numpy as np
+from lighteval.metrics.metrics_corpus import CorpusLevelTranslationMetric
from lighteval.metrics.metrics_sample import (
+ BLEU,
ExactMatches,
F1_score,
LoglikelihoodAcc,
@@ -38,6 +40,7 @@
LogProbTokenNorm,
get_multilingual_normalizer,
)
+from lighteval.metrics.sample_preparator import GenerativePreparator
from lighteval.metrics.utils.extractive_match_utils import ( # noqa: F401
ExprExtractionConfig,
ExtractionTarget,
@@ -47,7 +50,7 @@
get_extraction_regexes,
)
from lighteval.metrics.utils.math_comparison import compare_gold_target
-from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
+from lighteval.metrics.utils.metric_utils import CorpusLevelMetric, MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
from lighteval.utils.timeout import timeout
@@ -122,7 +125,10 @@ def probability_metric(
def multilingual_quasi_f1_score_metric(
- language: Language, aggregation_function: Callable[[list[float]], float] = max
+ language: Language,
+ aggregation_function: Callable[[list[float]], float] = max,
+ normalize_gold: Callable[[str], str] | None = None,
+ normalize_pred: Callable[[str], str] | None = None,
) -> SampleLevelMetric:
"""
Creates a language-aware F1 score metric, which returns the F1 score.
@@ -130,18 +136,23 @@ def multilingual_quasi_f1_score_metric(
Args:
language: The language of the samples.
aggregation_function: Aggregation samples to use when multiple golds are present.
+ normalize_gold: Normalization function for gold answers.
+ normalize_pred: Normalization function for predictions.
Returns:
F1 score metric.
"""
metric_name = f"f1_{language.value}"
- multilang_normalizer = get_multilingual_normalizer(language)
+ base_normalizer = get_multilingual_normalizer(language)
+ gold_normalizer = (lambda x: base_normalizer(normalize_gold(x))) if normalize_gold is not None else base_normalizer
+ pred_normalizer = (lambda x: base_normalizer(normalize_pred(x))) if normalize_pred is not None else base_normalizer
+
return SampleLevelMetric(
metric_name=metric_name,
sample_level_fn=F1_score(
- normalize_gold=multilang_normalizer,
- normalize_pred=multilang_normalizer,
+ normalize_gold=gold_normalizer,
+ normalize_pred=pred_normalizer,
aggregation_function=aggregation_function,
).compute,
category=MetricCategory.GENERATIVE,
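# --- Editor's sketch (illustrative, not part of the patch) --------------------
# The new hooks compose rather than replace the language normalizer: the
# task-supplied function runs first, i.e. pred_normalizer(x) ==
# base_normalizer(normalize_pred(x)). `strip_cot` is a hypothetical CoT stripper.
from lighteval.metrics.dynamic_metrics import multilingual_quasi_f1_score_metric
from lighteval.utils.language import Language

strip_cot = lambda x: x.rsplit("Answer:", 1)[-1]
f1 = multilingual_quasi_f1_score_metric(Language.ENGLISH, normalize_pred=strip_cot)
# -------------------------------------------------------------------------------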
@@ -155,6 +166,8 @@ def multilingual_quasi_exact_match_metric(
language: Language,
match_type: Literal["prefix", "suffix", "full"] = "full",
aggregation_function: Callable[[list[float]], float] = max,
+ normalize_gold: Callable[[str], str] | None = None,
+ normalize_pred: Callable[[str], str] | None = None,
) -> SampleLevelMetric:
"""
Creates a language-aware exact match metric, which returns the exact match score
@@ -165,16 +178,21 @@ def multilingual_quasi_exact_match_metric(
- "suffix": Suffixes must match
- "full": Full strings must match
aggregation_function: Aggregation samples to use when multiple golds are present.
+ normalize_gold: Normalization function for gold answers.
+ normalize_pred: Normalization function for predictions.
Returns:
Exact match metric.
"""
metric_name = f"exact_match_{language.value}_{match_type}"
- multilang_normalizer = get_multilingual_normalizer(language)
+ base_normalizer = get_multilingual_normalizer(language)
+ gold_normalizer = (lambda x: base_normalizer(normalize_gold(x))) if normalize_gold is not None else base_normalizer
+ pred_normalizer = (lambda x: base_normalizer(normalize_pred(x))) if normalize_pred is not None else base_normalizer
+
return SampleLevelMetric(
metric_name=metric_name,
sample_level_fn=ExactMatches(
- normalize_gold=multilang_normalizer,
- normalize_pred=multilang_normalizer,
+ normalize_gold=gold_normalizer,
+ normalize_pred=pred_normalizer,
aggregation_function=aggregation_function,
type_exact_match=match_type,
).compute,
@@ -185,6 +203,38 @@ def multilingual_quasi_exact_match_metric(
)
+def translation_metric(
+ metric_name: Literal["bleu", "bleu_1", "bleu_4", "chrf", "chrf++"],
+ normalize_pred: Callable[[str], str] | None = None,
+ normalize_gold: Callable[[str], str] | None = None,
+) -> CorpusLevelMetric | SampleLevelMetric:
+ """
+ Creates a translation metric, which returns the translation score.
+ """
+ if metric_name.startswith("bleu_"):
+ return SampleLevelMetric(
+ metric_name=metric_name,
+ sample_level_fn=BLEU(
+ n_gram=int(metric_name.split("_")[1]), normalize_pred=normalize_pred, normalize_gold=normalize_gold
+ ).compute,
+ category=MetricCategory.GENERATIVE,
+ use_case=MetricUseCase.TRANSLATION,
+ corpus_level_fn=np.mean,
+ higher_is_better=True,
+ )
+ else:
+ return CorpusLevelMetric(
+ metric_name=metric_name,
+ sample_level_fn=GenerativePreparator().prepare,
+ category=MetricCategory.GENERATIVE,
+ use_case=MetricUseCase.TRANSLATION,
+ corpus_level_fn=CorpusLevelTranslationMetric(
+ metric_name, normalize_pred=normalize_pred, normalize_gold=normalize_gold
+ ).compute, # type: ignore
+ higher_is_better=True,
+ )
+
+
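# --- Editor's sketch (illustrative, not part of the patch) --------------------
# translation_metric routes "bleu_N" names to the sample-level nltk BLEU scorer
# and everything else to the corpus-level sacrebleu scorer; both variants accept
# the same optional normalizers.
from lighteval.metrics.dynamic_metrics import translation_metric

bleu4 = translation_metric("bleu_4", normalize_pred=str.strip)  # SampleLevelMetric
chrf = translation_metric("chrf++", normalize_pred=str.strip)   # CorpusLevelMetric
# -------------------------------------------------------------------------------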
def multilingual_extractive_match_metric(
language: Language = Language.ENGLISH,
gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py
index 030725a53..b0d488f8f 100644
--- a/src/lighteval/metrics/metrics_corpus.py
+++ b/src/lighteval/metrics/metrics_corpus.py
@@ -27,7 +27,7 @@
import logging
import math
-from typing import Literal
+from typing import Callable, Literal
import numpy as np
import sacrebleu
@@ -91,7 +91,13 @@ def compute(self, items: list[LogprobCorpusMetricInput]):
class CorpusLevelTranslationMetric:
- def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
+ def __init__(
+ self,
+ metric_type: Literal["bleu", "chrf", "chrf++", "ter"],
+ lang: Literal["zh", "ja", "ko", ""] = "",
+ normalize_pred: Callable[[str], str] | None = None,
+ normalize_gold: Callable[[str], str] | None = None,
+ ):
"""Stores the relevant parameters for a corpus level translation metric.
Args:
@@ -99,6 +105,8 @@ def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
"""
self.metric_type = metric_type
self.lang = lang
+ self.normalize_pred = normalize_pred if normalize_pred is not None else lambda x: x
+ self.normalize_gold = normalize_gold if normalize_gold is not None else lambda x: x
def get_metric(self):
if self.metric_type == "bleu":
@@ -115,7 +123,7 @@ def get_metric(self):
def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
"""Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
metric = self.get_metric()
- golds = [i.golds for i in items]
+ golds = [[self.normalize_gold(gold) for gold in i.golds] for i in items]
preds = []
for i in items:
pred = as_list(i.preds)
@@ -123,7 +131,7 @@ def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
logger.info(
f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{metric.__name__})."
)
- preds.append(pred[0])
+ preds.append(self.normalize_pred(pred[0]))
return float(metric.corpus_score(hypotheses=preds, references=golds).score)
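# --- Editor's sketch (illustrative, not part of the patch) --------------------
# The normalizers default to the identity, so existing callers are unaffected.
# Golds are normalized per reference; only the first prediction is kept and
# normalized. `FakeItem` is a hypothetical stand-in for GenerativeCorpusMetricInput.
from dataclasses import dataclass

from lighteval.metrics.metrics_corpus import CorpusLevelTranslationMetric


@dataclass
class FakeItem:
    golds: list[str]
    preds: list[str]


chrf = CorpusLevelTranslationMetric("chrf", normalize_pred=str.lower)
score = chrf.compute([FakeItem(golds=["Der Hund schläft."], preds=["DER HUND SCHLÄFT."])])
# -------------------------------------------------------------------------------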
diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index 984e2607a..795aa0918 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -744,7 +744,12 @@ def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
class BLEU:
- def __init__(self, n_gram: int):
+ def __init__(
+ self,
+ n_gram: int,
+ normalize_pred: Callable[[str], str] | None = None,
+ normalize_gold: Callable[[str], str] | None = None,
+ ):
"""BLEU scorer class. Relies on `nltk`'s sentencebleu for scoring.
TODO: Will have to move this to sacrebleu.
@@ -752,6 +757,8 @@ def __init__(self, n_gram: int):
n_gram (int): Number of n_grams to use for scoring.
"""
self.n_gram = n_gram
+ self.normalize_pred = normalize_pred
+ self.normalize_gold = normalize_gold
def compute(self, golds: list[str], predictions: list[str], **kwargs):
"""Computes the sentence level BLEU between the golds and each prediction, then takes the average.
@@ -763,6 +770,10 @@ def compute(self, golds: list[str], predictions: list[str], **kwargs):
Returns:
float: Score over the current sample's items.
"""
+ if self.normalize_pred:
+ predictions = [self.normalize_pred(p) for p in predictions]
+ if self.normalize_gold:
+ golds = [self.normalize_gold(g) for g in golds]
return np.mean([self._bleu_score(golds, p) for p in predictions])
def _bleu_score(self, gold: list[str], pred: str) -> float:
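# --- Editor's sketch (illustrative, not part of the patch) --------------------
# Sample-level counterpart: the optional normalizers are applied to every gold
# and prediction before nltk's sentence BLEU is averaged over the predictions.
from lighteval.metrics.metrics_sample import BLEU

bleu1 = BLEU(n_gram=1, normalize_pred=str.strip, normalize_gold=str.strip)
score = bleu1.compute(golds=["the cat sat"], predictions=["  the cat sat  "])
# -------------------------------------------------------------------------------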
diff --git a/src/lighteval/metrics/utils/extractive_match_utils.py b/src/lighteval/metrics/utils/extractive_match_utils.py
index b8c529d05..799b580de 100644
--- a/src/lighteval/metrics/utils/extractive_match_utils.py
+++ b/src/lighteval/metrics/utils/extractive_match_utils.py
@@ -89,6 +89,7 @@ class IndicesExtractionConfig:
"""
prefix_for_extraction: ChoicePrefix
+ bb_match_priority: int = -1
try_extract_without_anchor: bool = True
@@ -340,6 +341,9 @@ def lazy_indices_regex(
]
)
+ if indices_config.bb_match_priority >= 0:
+ regexes.append((rf"\s*{indice_str_re}\s*", indices_config.bb_match_priority))
+
return [(re.compile(pattern), priority) for pattern, priority in regexes]
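# --- Editor's sketch (illustrative, not part of the patch) --------------------
# With bb_match_priority >= 0, a bare choice label with no anchoring text (e.g.
# a reply that is just "B") becomes extractable at the given priority; the
# default of -1 preserves the previous behavior. The prefix value is an assumption.
from lighteval.metrics.utils.extractive_match_utils import IndicesExtractionConfig

cfg = IndicesExtractionConfig(prefix_for_extraction="NativeLetters", bb_match_priority=0)
# -------------------------------------------------------------------------------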
diff --git a/src/lighteval/tasks/multilingual/adapters.py b/src/lighteval/tasks/multilingual/adapters.py
index 59b39fcd6..0ca18479c 100644
--- a/src/lighteval/tasks/multilingual/adapters.py
+++ b/src/lighteval/tasks/multilingual/adapters.py
@@ -29,6 +29,7 @@
from lighteval.tasks.default_prompts import LETTER_INDICES
from lighteval.tasks.multilingual.utils.adapters_utils import (
extract_answers_from_string,
+ float_to_choice_string,
multichoice_join,
multichoice_to_single_choice,
)
@@ -79,7 +80,7 @@ def thai_exams_adapter(line: dict) -> MCQInput | None:
def alghafa_adapter(line: dict) -> MCQInput | None:
answer_index = int(line["label"])
- choices_keys = [key for key in line.keys() if key not in ["query", "label", "__few_shots"]]
+ choices_keys = [key for key in line.keys() if key not in ["query", "label", "__index", "__few_shots"]]
choices = [line[key] for key in choices_keys]
return {
"question": line["query"],
@@ -298,3 +299,55 @@ def enem_adapter(lang: Language, line: dict) -> MCQInput | None:
"choices": line["alternatives"],
"gold_idx": LETTER_INDICES.index(line["label"]),
}
+
+
+CMM_MATH_ANSWER_RE = re.compile(r"([A-D])\.(.*?)(?=[A-D]\.|$)")
+
+
+def cmm_math_adapter(line: dict) -> MCQInput | None:
+ """Adapter for CMM-Math dataset.
+
+ Processes questions and options, handling cases where:
+ - Question ends with parentheses that need to be stripped
+ - Options are space-separated strings starting with A./B./C./D.
+ """
+    # Strip trailing parentheses/blanks from the question (note: rstrip("( )")
+    # removes any trailing '(', ')' or space characters, not the literal substring)
+    question = line["question"].strip().rstrip("( )")
+
+ # Split options and store as dict with letter keys
+ choices = {}
+ for match in CMM_MATH_ANSWER_RE.finditer(line["options"]):
+ letter, choice = match.groups()
+ choices[letter] = choice.strip()
+
+ try:
+ gold_idx = list(choices.keys()).index(line["answer"])
+ except ValueError:
+ gold_idx = None
+
+ # Validate we have enough options and answer
+ if len(choices) <= 1 or not line.get("answer") or gold_idx is None:
+ return None
+
+ return {"question": question, "choices": list(choices.values()), "gold_idx": gold_idx}
+
+
+def qazuntv2_adapter(line: dict) -> MCQInput | None:
+ gold_idx = LETTER_INDICES.index(line["answer"])
+ choices = line["options"]
+ if gold_idx >= len(choices):
+ return None
+ return {"question": line["question"], "choices": choices, "gold_idx": gold_idx}
+
+
+MGSM_COT_PREFIX_RE = re.compile(
+ r"\s*(ধাপে ধাপে উত্তর|Schritt-für-Schritt-Antwort|Step-by-Step Answer|Respuesta paso a paso|Réponse étape par étape|ステップごとの答え|Пошаговое решение|Jibu la Hatua kwa Hatua|దశలవారీగా సమాధానంi|คำตอบทีละขั้นตอน|逐步解答)\s*:\s*"
+)
+MGSM_QUESTION_RE = re.compile(r"\s*(প্রশ্ন|Frage|Question|Pregunta|Question|問題|Задача|Swali|ప్రశ్న|โจทย์|问题)\s*:\s*")
+
+
+def mgsm_adapter(line: dict) -> QAInput | None:
+ question = MGSM_QUESTION_RE.sub("", line["question"])
+ answer_cot = MGSM_COT_PREFIX_RE.sub("", line["answer"]) if line["answer"] else ""
+ answer_number = line["answer_number"]
+ return {"question": question, "few_shot_cot": answer_cot, "choices": [float_to_choice_string(answer_number)]}
diff --git a/src/lighteval/tasks/multilingual/tasks.py b/src/lighteval/tasks/multilingual/tasks.py
index e94a1e400..70e3466dc 100644
--- a/src/lighteval/tasks/multilingual/tasks.py
+++ b/src/lighteval/tasks/multilingual/tasks.py
@@ -28,8 +28,10 @@
from lighteval.metrics.dynamic_metrics import (
loglikelihood_acc_metric,
+ multilingual_extractive_match_metric,
multilingual_quasi_exact_match_metric,
multilingual_quasi_f1_score_metric,
+ translation_metric,
)
from lighteval.metrics.metrics import Metrics
from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm
@@ -39,25 +41,36 @@
agieval_adapter,
alghafa_adapter,
ceval_adapter,
- enem_adapter,
+ cmm_math_adapter,
get_m3exam_adapter,
get_mkqa_adapter,
+ mgsm_adapter,
+ qazuntv2_adapter,
sciqa_adapter,
thai_exams_adapter,
winogrand_adapter,
xcodah_adapter,
)
-from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation, normalize_subset
+from lighteval.tasks.multilingual.utils.adapters_utils import float_to_choice_string
+from lighteval.tasks.multilingual.utils.task_utils import (
+ get_cot_answer_normalization,
+ get_cot_generaion_size,
+ get_metrics_for_mcq_formulation,
+ get_stop_sequence,
+ normalize_subset,
+)
from lighteval.tasks.templates.boolq import get_boolq_prompt_function
from lighteval.tasks.templates.continuation import get_continuation_prompt_function
from lighteval.tasks.templates.copa import get_copa_prompt_function
from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function
+from lighteval.tasks.templates.math_qa import get_math_qa_prompt_function
from lighteval.tasks.templates.multichoice import get_mcq_prompt_function
from lighteval.tasks.templates.nli import get_nli_prompt_function
from lighteval.tasks.templates.qa import get_qa_prompt_function
from lighteval.tasks.templates.translation import get_translation_prompt_function
from lighteval.tasks.templates.utils.formulation import (
CFFormulation,
+ Formulation,
HybridFormulation,
MCFFormulation,
)
@@ -65,7 +78,23 @@
from lighteval.utils.language import Language, iso_639_3_ind_to_iso_639_3_macro, manage_duplicate_language_codes
+# Philosophy of formulations:
+# 1. For early-stage pretrained models, we recommend the CF formulation with few-shots (task_cf); this gives a reasonable signal even at this stage
+# 2. For later stages, we recommend the MCF formulation with few-shots (task_mcf_native), as models at this point should be able to handle MCF
+# 3. For post-trained models, we recommend the MCF formulation with CoT and without few-shots (task_mcf_native_cot); this best matches their real
+#    usage, and they should be capable of following the expected format
+
+# Similarly, for generative tasks we recommend the non-CoT variants for all pre-trained models and the CoT variants for post-trained models
+
+
TASKS_TABLE = []
+DEFAULT_FORMULATIONS: list[Formulation] = [
+ MCFFormulation(),
+ MCFFormulation(cot=True),
+ MCFFormulation("NativeLetters"),
+ MCFFormulation("NativeLetters", cot=True),
+ CFFormulation(),
+]
# ------------------------------- NLI Tasks ------------------------------- #
# NLI (Natural Language Inference) tasks involve determining the logical relationship
# between two given sentences: a premise and a hypothesis. The goal is to classify
@@ -76,12 +105,41 @@
# The XNLI dataset is a multilingual variant of MultiNLI
# https://aclanthology.org/D18-1269/
+
+
+def get_formulation_name(formulation: Formulation) -> str:
+ match formulation:
+ case MCFFormulation("NativeLetters") | HybridFormulation("NativeLetters"):
+ name = formulation.name.lower() + "_native"
+ case MCFFormulation("Numbers") | HybridFormulation("Numbers"):
+ name = formulation.name.lower() + "_numbers"
+ case _:
+ name = formulation.name.lower()
+ if formulation.cot:
+ name += "_cot"
+ return name
+
+
+def get_mcf_task_name(base_name: str, language: Language | str, formulation: Formulation | None) -> str:
+ """Helper function to generate consistent task names."""
+ formulation_name = f"_{get_formulation_name(formulation)}" if formulation else ""
+
+ language_name = language.value if isinstance(language, Language) else language
+ return f"{base_name}_{language_name}{formulation_name}"
+
+
+def get_generative_task_name(base_name: str, language: Language | str, cot: bool) -> str:
+ language_name = language.value if isinstance(language, Language) else language
+ return f"{base_name}_{language_name}{'_cot' if cot else ''}"
+
+
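# --- Editor's sketch (illustrative, not part of the patch; assumes ISO 639-3
# language values, e.g. Language.ENGLISH.value == "eng") -----------------------
get_mcf_task_name("xnli", Language.ENGLISH, MCFFormulation("NativeLetters", cot=True))
# -> "xnli_eng_mcf_native_cot"
get_mcf_task_name("xnli", Language.ENGLISH, CFFormulation())
# -> "xnli_eng_cf"
get_generative_task_name("xquad", Language.ENGLISH, cot=True)
# -> "xquad_eng_cot"
# -------------------------------------------------------------------------------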
xnli_tasks = [
LightevalTaskConfig(
- name=f"xnli_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("xnli", language, formulation),
suite=["lighteval"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=None),
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
@@ -104,6 +162,7 @@
hf_subset=standardize_tag(language.value),
evaluation_splits=["validation"],
few_shots_split="train",
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.ARABIC,
@@ -124,19 +183,19 @@
Language.VIETNAMESE,
Language.CHINESE,
]
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
-
# Improvement on XNLI with better translation, from our experience models tend to
# perform better on XNLI2.0 than XNLI
# https://arxiv.org/abs/2301.06527
xnli2_tasks = [
LightevalTaskConfig(
- name=f"xnli2.0_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("xnli2.0", language, formulation),
suite=["lighteval"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=None),
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
@@ -161,6 +220,7 @@
hf_subset="default",
evaluation_splits=["train"],
hf_avail_splits=["train"],
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.ENGLISH,
@@ -189,14 +249,14 @@
Language.ARABIC,
# Theoretically also: Bhojpuri, Gujarati, Odiya
]
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
# Another variant of XNLI, with emphasis on Indic languages
# https://arxiv.org/abs/2204.08776
xnli_indic_tasks = [
LightevalTaskConfig(
- name=f"indicnxnli_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("indicxnli", language, formulation),
suite=["lighteval"],
prompt_function=get_nli_prompt_function(
language=language,
@@ -215,14 +275,16 @@
hf_filter=lambda x: int(x["label"]) in [0, 2],
evaluation_splits=["validation"],
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=None),
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.ASSAMESE,
@@ -237,14 +299,14 @@
Language.TAMIL,
Language.TELUGU,
]
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
# African XNLI: African XNLI
# From https://arxiv.org/abs/2406.03368. Human translated MMLU.
afri_xnli_tasks = [
LightevalTaskConfig(
- name=f"afri_xnli_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("afri_xnli", language, formulation),
suite=("lighteval",),
prompt_function=get_nli_prompt_function(
language=language,
@@ -262,14 +324,16 @@
hf_filter=lambda x: int(x["label"]) in [0, 2],
evaluation_splits=("test",),
few_shots_split="validation",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=None),
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.AMHARIC,
@@ -290,7 +354,7 @@
Language.YORUBA,
# Language.ZULU,
]
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
# PAWS-X: A Cross-lingual Adversarial Dataset for Paraphrase Identification
@@ -301,7 +365,7 @@
paws_x_tasks = [
LightevalTaskConfig(
- name=f"pawsx_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("pawsx", language, formulation),
suite=("lighteval",),
prompt_function=get_nli_prompt_function(
language=language,
@@ -318,14 +382,16 @@
hf_subset=standardize_tag(language.value),
evaluation_splits=("test",),
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=None),
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.GERMAN,
@@ -336,7 +402,7 @@
Language.KOREAN,
Language.CHINESE,
]
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
# Russian Commitment Bank (RCB) is a large-scale NLI dataset with Russian sentences,
@@ -344,13 +410,12 @@
# https://arxiv.org/abs/2401.04531
rcb_tasks = [
LightevalTaskConfig(
- name=f"rcb_{Language.RUSSIAN.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("rcb", Language.RUSSIAN, formulation),
prompt_function=get_nli_prompt_function(
language=Language.RUSSIAN,
adapter=lambda line: {
"premise": line["inputs"]["premise"],
"hypothesis": line["inputs"]["hypothesis"],
- # Since we ignore the neutral label
"gold_idx": int(line["outputs"]) - 1,
},
relations=["entailment", "contradiction"],
@@ -359,20 +424,21 @@
suite=("lighteval",),
hf_repo="ai-forever/MERA",
hf_subset="rcb",
- # Ignore neutral label
hf_filter=lambda x: int(x["outputs"] or "0") in [1, 2],
evaluation_splits=("train",),
few_shots_split="validation",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.RUSSIAN,
[
loglikelihood_acc_metric(normalization=None),
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot),
)
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
# Native Chinese NLI dataset based.
@@ -380,13 +446,12 @@
# We find this benchmark to have really good signal compared to other Chinese NLI
ocnli_tasks = [
LightevalTaskConfig(
- name=f"ocnli_{Language.CHINESE.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("ocnli", Language.CHINESE, formulation),
prompt_function=get_nli_prompt_function(
language=Language.CHINESE,
adapter=lambda line: {
"premise": line["sentence1"],
"hypothesis": line["sentence2"],
- # Since we ignore the neutral label
"gold_idx": {1: 0, 2: 1}[line["label"]],
},
relations=["entailment", "contradiction"],
@@ -395,33 +460,33 @@
suite=("lighteval",),
hf_repo="clue/clue",
hf_subset="ocnli",
- # Only keep the positive and negative examples
hf_filter=lambda x: int(x["label"]) in [1, 2],
evaluation_splits=("validation",),
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.CHINESE,
[
loglikelihood_acc_metric(normalization=None),
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot),
)
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
# https://arxiv.org/abs/2004.05986
# Native Chinese NLI dataset based on MNLI approach (Machine Translated)
cmnli_tasks = [
LightevalTaskConfig(
- name=f"cmnli_{Language.CHINESE.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("cmnli", Language.CHINESE, formulation),
prompt_function=get_nli_prompt_function(
language=Language.CHINESE,
adapter=lambda line: {
"premise": line["sentence1"],
"hypothesis": line["sentence2"],
- # Since we ignore the neutral label
"gold_idx": {"entailment": 0, "contradiction": 1}[line["label"]],
},
relations=["entailment", "contradiction"],
@@ -431,19 +496,20 @@
hf_repo="fenffef/cmnli",
hf_subset="default",
hf_filter=lambda x: x["label"] in ["entailment", "contradiction"],
- # Only keep the positive and negative examples
evaluation_splits=("validation",),
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.CHINESE,
[
loglikelihood_acc_metric(normalization=None),
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot),
)
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -467,7 +533,7 @@
# XCOPA extends the original English COPA task to 11 typologically diverse languages.
xcopa_tasks = [
LightevalTaskConfig(
- name=f"xcopa_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("xcopa", language, formulation),
suite=["lighteval"],
prompt_function=get_copa_prompt_function(
language,
@@ -483,13 +549,15 @@
hf_subset=("copa_ext_ar" if language == Language.ARABIC else standardize_tag(language.value)),
evaluation_splits=["test"],
few_shots_split="validation",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.ARABIC,
@@ -505,7 +573,7 @@
Language.HAITIAN,
Language.QUECHUA,
]
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
# IndicCOPA: COPA for Indic Languages
@@ -514,7 +582,7 @@
# evaluating common sense reasoning in these languages.
copa_indic_tasks = [
LightevalTaskConfig(
- name=f"indicxcopa_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("indicxcopa", language, formulation),
suite=["lighteval"],
prompt_function=get_copa_prompt_function(
language,
@@ -533,14 +601,16 @@
hf_revision="d356ef19a4eb287e88a51d07a56b73ba88c7f188",
evaluation_splits=["test"],
hf_avail_splits=["test"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
trust_dataset=True,
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.ASSAMESE,
@@ -560,7 +630,7 @@
Language.URDU,
# Optionally: Maithili, Santali, Sindhi, Konkani
]
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
# PARus: Plausible Alternatives for Russian
@@ -569,7 +639,7 @@
# It evaluates common sense reasoning and causal inference abilities in Russian language models.
parus_tasks = [
LightevalTaskConfig(
- name=f"parus_{Language.RUSSIAN.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("parus", Language.RUSSIAN, formulation),
suite=["lighteval"],
prompt_function=get_copa_prompt_function(
language=Language.RUSSIAN,
@@ -585,17 +655,21 @@
hf_subset="parus",
evaluation_splits=["train"],
few_shots_split="validation",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.RUSSIAN,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot),
)
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
+# TODO: Rerun hellaswag tasks
+# TODO: Add xcopa for Thai
TASKS_TABLE.extend([*xcopa_tasks, *copa_indic_tasks, *parus_tasks])
# ------------------------------- Hellaswag Tasks ------------------------------- #
@@ -609,12 +683,12 @@
# It evaluates commonsense reasoning abilities across multiple languages.
mlmm_hellaswag_tasks = [
LightevalTaskConfig(
- name=f"mlmm_hellaswag_{lang.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("mlmm_hellaswag", lang, formulation),
suite=["lighteval"],
prompt_function=get_hellaswag_prompt_function(
language=lang,
adapter=lambda line: {
- # We don't use activity_label as they are not available
+            # activity_label is only available in English, so we don't use it
"ctx_a": line["ctx_a"],
"ctx_b": line["ctx_b"],
"continuations": line["endings"],
@@ -629,14 +703,16 @@
hf_revision="96ed8e0dfc6172dad1d3df338d7b8ba6c1ff9d83",
evaluation_splits=["validation"],
hf_avail_splits=["validation"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ lang,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
trust_dataset=True,
+ stop_sequence=get_stop_sequence(lang, formulation.cot),
)
for lang in [
Language.ARABIC,
@@ -673,7 +749,7 @@
Language.VIETNAMESE,
Language.CHINESE,
]
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
# Hellaswag Turkish
@@ -685,13 +761,12 @@
# which would make it hard to read
hellaswag_tur_tasks = [
LightevalTaskConfig(
- name=f"community_hellaswag_{Language.TURKISH.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("community_hellaswag", Language.TURKISH, formulation),
suite=["lighteval"],
prompt_function=get_hellaswag_prompt_function(
language=Language.TURKISH,
adapter=lambda line: {
- "ctx_a": line["ctx_a"],
- "ctx_b": line["ctx_b"],
+ "ctx_a": line["ctx"],
"continuations": line["endings"],
"gold_idx": int(line["label"]),
},
@@ -703,15 +778,17 @@
hf_subset="default",
evaluation_splits=["validation"],
hf_avail_splits=["validation"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.TURKISH,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot),
)
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
# Hellaswag Thai
@@ -720,13 +797,13 @@
# for evaluating Thai language models on commonsense reasoning tasks.
hellaswag_tha_tasks = [
LightevalTaskConfig(
- name=f"community_hellaswag_{Language.THAI.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("community_hellaswag", Language.THAI, formulation),
suite=["lighteval"],
prompt_function=get_hellaswag_prompt_function(
language=Language.THAI,
adapter=lambda line: {
- "ctx_a": line["ctx_a"],
- "ctx_b": line["ctx_b"],
+ "activity_label": line["activity_label"],
+ "ctx_a": line["ctx"],
"continuations": line["endings"],
"gold_idx": int(line["label"]),
},
@@ -737,72 +814,49 @@
hf_subset="default",
evaluation_splits=["validation"],
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.THAI,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.THAI, formulation.cot),
)
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
-]
-
-hellaswag_hin_tasks = [
- LightevalTaskConfig(
- name=f"community_hellaswag_{Language.HINDI.value}_{formulation.name.lower()}",
- suite=["lighteval"],
- prompt_function=get_hellaswag_prompt_function(
- language=Language.HINDI,
- adapter=lambda line: {
- "ctx_a": line["ctx_a"],
- "continuations": line["endings"],
- "gold_idx": int(line["label"]),
- },
- formulation=formulation,
- ),
- hf_repo="ai4bharat/hellaswag-hi",
- hf_filter=lambda line: all(len(choice.strip()) > 0 for choice in line["endings"]),
- hf_subset="hi",
- evaluation_splits=("validation",),
- few_shots_split="validation",
- metric=get_metrics_for_formulation(
- formulation,
- [
- loglikelihood_acc_metric(normalization=LogProbCharNorm()),
- loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
- ],
- ),
- )
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
hellaswag_tel_tasks = [
LightevalTaskConfig(
- name=f"community_hellaswag_{Language.TELUGU.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("community_hellaswag", Language.TELUGU, formulation),
suite=["lighteval"],
prompt_function=get_hellaswag_prompt_function(
language=Language.TELUGU,
adapter=lambda line: {
+                # activity_label is only available in English, so we don't use it
"ctx_a": line["ctx_a"],
- "continuations": line["endings"],
+ "continuations": [line["a"], line["b"], line["c"], line["d"]],
"gold_idx": int(line["label"]),
},
formulation=formulation,
+ wikihow_artifacts=[" [శీర్షిక]", " [హెడర్]", " [header]", " [Header]"],
),
- hf_repo="LightFury9/hellaswag-telugu",
+ hf_repo="indiehackers/hellaswag-telugu-custom-2k",
hf_subset="default",
- evaluation_splits=("valid",),
- few_shots_split="train",
- metric=get_metrics_for_formulation(
+ evaluation_splits=("train",),
+ hf_avail_splits=("train",),
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.TELUGU,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.TELUGU, formulation.cot),
)
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -810,7 +864,6 @@
*mlmm_hellaswag_tasks,
*hellaswag_tur_tasks,
*hellaswag_tha_tasks,
- *hellaswag_hin_tasks,
*hellaswag_tel_tasks,
]
)
@@ -825,7 +878,7 @@
# https://arxiv.org/abs/1910.11856
xquad_tasks = [
LightevalTaskConfig(
- name=f"xquad_{language.value}",
+ name=get_generative_task_name("xquad", language, cot),
prompt_function=get_qa_prompt_function(
language,
lambda line: {
@@ -833,18 +886,21 @@
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="google/xquad",
hf_subset=f"xquad.{standardize_tag(language.value)}",
evaluation_splits=("validation",),
few_shots_split="validation",
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
metric=(
- multilingual_quasi_exact_match_metric(language, "prefix"),
- multilingual_quasi_f1_score_metric(language),
+ multilingual_quasi_exact_match_metric(
+ language, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(language, normalize_pred=get_cot_answer_normalization(cot)),
),
+ stop_sequence=get_stop_sequence(language, cot),
)
for language in [
Language.ARABIC,
@@ -860,72 +916,13 @@
Language.VIETNAMESE,
Language.CHINESE,
]
+ for cot in [False, True]
]
-# GermanQuAD: High-quality German QA dataset with 13,722 questions
-# https://arxiv.org/abs/2104.12741
-germanquad_tasks = [
- LightevalTaskConfig(
- name=f"germanquad_{Language.GERMAN.value}",
- prompt_function=get_qa_prompt_function(
- Language.GERMAN,
- lambda line: {
- "question": line["question"],
- "context": line["context"],
- "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
- },
- ),
- suite=("lighteval",),
- hf_repo="deepset/germanquad",
- hf_subset="plain_text",
- trust_dataset=True,
- hf_revision="fff05ceaf2ffbe5b65c7e0c57e678f7b7e1a0581",
- hf_filter=lambda line: any(len(ans) > 0 for ans in line["answers"]["text"]),
- evaluation_splits=("test",),
- few_shots_split="train",
- generation_size=400,
- stop_sequence=("\n",),
- metric=(
- multilingual_quasi_exact_match_metric(Language.GERMAN, "prefix"),
- multilingual_quasi_f1_score_metric(Language.GERMAN),
- ),
- )
-]
-
-
-# SQuAD-it: Italian translation of the SQuAD dataset
-# https://github.com/crux82/squad-it
-squad_it_tasks = [
- LightevalTaskConfig(
- name=f"squad_{Language.ITALIAN.value}",
- prompt_function=get_qa_prompt_function(
- Language.ITALIAN,
- lambda line: {
- "question": line["question"],
- "context": line["context"],
- "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
- },
- ),
- suite=("lighteval",),
- hf_repo="crux82/squad_it",
- hf_subset="default",
- hf_filter=lambda line: any(len(ans) > 0 for ans in line["answers"]["text"]),
- evaluation_splits=("test",),
- few_shots_split="train",
- generation_size=400,
- stop_sequence=("\n",),
- metric=(
- multilingual_quasi_exact_match_metric(Language.ITALIAN, "prefix"),
- multilingual_quasi_f1_score_metric(Language.ITALIAN),
- ),
- )
-]
-
-
# ThaiQA: A question answering dataset for the Thai language.
thaiqa_tasks = [
LightevalTaskConfig(
- name=f"thaiqa_{Language.THAI.value}",
+ name=get_generative_task_name("thaiqa", Language.THAI, cot),
prompt_function=get_qa_prompt_function(
Language.THAI,
lambda line: {
@@ -933,26 +930,30 @@
"context": line["context"],
"choices": [ans for ans in line["answers"]["answer"] if len(ans) > 0],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="lighteval/thaiqa_squad_fixed",
hf_subset="default",
evaluation_splits=("train",),
few_shots_split="validation",
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
+ stop_sequence=get_stop_sequence(Language.THAI, cot),
metric=(
- multilingual_quasi_exact_match_metric(Language.THAI, "prefix"),
- multilingual_quasi_f1_score_metric(Language.THAI),
+ multilingual_quasi_exact_match_metric(
+ Language.THAI, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(Language.THAI, normalize_pred=get_cot_answer_normalization(cot)),
),
)
+ for cot in [False, True]
]
# SberQuAD: A large-scale Russian reading comprehension dataset.
# https://arxiv.org/abs/1912.09723
sber_squad_tasks = [
LightevalTaskConfig(
- name=f"sber_squad_{Language.RUSSIAN.value}",
+ name=get_generative_task_name("sber_squad", Language.RUSSIAN, cot),
prompt_function=get_qa_prompt_function(
Language.RUSSIAN,
lambda line: {
@@ -960,6 +961,7 @@
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="kuznetsoffandrey/sberquad",
@@ -967,79 +969,22 @@
evaluation_splits=("validation",),
few_shots_split="train",
metric=(
- multilingual_quasi_exact_match_metric(Language.RUSSIAN, "prefix"),
- multilingual_quasi_f1_score_metric(Language.RUSSIAN),
- ),
- generation_size=400,
- stop_sequence=("\n",),
- )
-]
-
-# FaQuAD: A Portuguese Reading Comprehension Dataset
-# https://arxiv.org/abs/2007.15671
-faquad_tasks = [
- LightevalTaskConfig(
- name=f"faquad_{Language.PORTUGUESE.value}",
- prompt_function=get_qa_prompt_function(
- Language.PORTUGUESE,
- lambda line: {
- "question": line["question"],
- "context": line["context"],
- "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
- },
- ),
- suite=("lighteval",),
- hf_repo="eraldoluis/faquad",
- hf_subset="plain_text",
- trust_dataset=True,
- hf_revision="205ba826a2282a4a5aa9bd3651e55ee4f2da1546",
- hf_filter=lambda line: any(len(ans) > 0 for ans in line["answers"]["text"]),
- evaluation_splits=("validation",),
- few_shots_split="train",
- metric=(
- multilingual_quasi_exact_match_metric(Language.PORTUGUESE, "prefix"),
- multilingual_quasi_f1_score_metric(Language.PORTUGUESE),
- ),
- generation_size=400,
- stop_sequence=("\n",),
- )
-]
-
-
-# SQuAD-es: Spanish translation of the Stanford Question Answering Dataset
-# https://huggingface.co/datasets/ccasimiro/squad_es
-squad_es_tasks = [
- LightevalTaskConfig(
- name=f"squad_{Language.SPANISH.value}",
- prompt_function=get_qa_prompt_function(
- Language.SPANISH,
- lambda line: {
- "question": line["question"],
- "context": line["context"],
- "choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
- },
- ),
- suite=("lighteval",),
- hf_repo="ccasimiro/squad_es",
- hf_subset="v2.0.0",
- hf_filter=lambda line: any(len(ans) > 0 for ans in line["answers"]["text"]),
- evaluation_splits=("validation",),
- few_shots_split="train",
- metric=(
- multilingual_quasi_exact_match_metric(Language.SPANISH, "prefix"),
- multilingual_quasi_f1_score_metric(Language.SPANISH),
+ multilingual_quasi_exact_match_metric(
+ Language.RUSSIAN, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(Language.RUSSIAN, normalize_pred=get_cot_answer_normalization(cot)),
),
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
+ stop_sequence=get_stop_sequence(Language.RUSSIAN, cot),
)
+ for cot in [False, True]
]
-
# ARCD: Arabic Reading Comprehension Dataset.
# https://arxiv.org/pdf/1906.05394
arcd_tasks = [
LightevalTaskConfig(
- name=f"arcd_{Language.ARABIC.value}",
+ name=get_generative_task_name("arcd", Language.ARABIC, cot),
prompt_function=get_qa_prompt_function(
Language.ARABIC,
lambda line: {
@@ -1047,6 +992,7 @@
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="hsseinmz/arcd",
@@ -1057,16 +1003,17 @@
-            multilingual_quasi_exact_match_metric(Language.ARABIC, "prefix"),
-            multilingual_quasi_f1_score_metric(Language.ARABIC),
+            multilingual_quasi_exact_match_metric(
+                Language.ARABIC, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+            ),
+            multilingual_quasi_f1_score_metric(Language.ARABIC, normalize_pred=get_cot_answer_normalization(cot)),
),
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
+ stop_sequence=get_stop_sequence(Language.ARABIC, cot),
)
+ for cot in [False, True]
]
# KenSwQuAD: A question answering dataset for Kenyan Swahili.
# https://arxiv.org/abs/2205.02364
kenswquad_tasks = [
LightevalTaskConfig(
- name=f"kenswquad_{Language.SWAHILI.value}",
+ name=get_generative_task_name("kenswquad", Language.SWAHILI, cot),
prompt_function=get_qa_prompt_function(
Language.SWAHILI,
lambda line: {
@@ -1074,6 +1021,7 @@
"context": line["context"],
"choices": [line["answer"]],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="lighteval/KenSwQuAD",
@@ -1081,19 +1029,22 @@
evaluation_splits=("test",),
few_shots_split="validation",
metric=(
- multilingual_quasi_exact_match_metric(Language.SWAHILI, "prefix"),
- multilingual_quasi_f1_score_metric(Language.SWAHILI),
+ multilingual_quasi_exact_match_metric(
+ Language.SWAHILI, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(Language.SWAHILI, normalize_pred=get_cot_answer_normalization(cot)),
),
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
+ stop_sequence=get_stop_sequence(Language.SWAHILI, cot),
)
+ for cot in [False, True]
]
# ChineseSquad: A reading comprehension dataset for Chinese.
# https://github.com/pluto-junzeng/ChineseSquad
chinese_squad_tasks = [
LightevalTaskConfig(
- name=f"chinese_squad_{Language.CHINESE.value}",
+ name=get_generative_task_name("chinese_squad", Language.CHINESE, cot),
prompt_function=get_qa_prompt_function(
Language.CHINESE,
lambda line: {
@@ -1101,6 +1052,7 @@
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="lighteval/ChineseSquad",
@@ -1108,19 +1060,22 @@
evaluation_splits=("validation",),
few_shots_split="train",
metric=(
- multilingual_quasi_exact_match_metric(Language.CHINESE, "prefix"),
- multilingual_quasi_f1_score_metric(Language.CHINESE),
+ multilingual_quasi_exact_match_metric(
+ Language.CHINESE, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(Language.CHINESE, normalize_pred=get_cot_answer_normalization(cot)),
),
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
+ stop_sequence=get_stop_sequence(Language.CHINESE, cot),
)
+ for cot in [False, True]
]
# CMRC 2018: A span-extraction machine reading comprehension dataset for Chinese.
# https://arxiv.org/abs/1810.07366
cmrc2018_tasks = [
LightevalTaskConfig(
- name=f"cmrc2018_{Language.CHINESE.value}",
+ name=get_generative_task_name("cmrc2018", Language.CHINESE, cot),
prompt_function=get_qa_prompt_function(
Language.CHINESE,
lambda line: {
@@ -1128,26 +1083,30 @@
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="clue/clue",
hf_subset="cmrc2018",
evaluation_splits=("trial",),
few_shots_split="train",
- generation_size=400,
+ generation_size=get_cot_generaion_size(cot, 400),
metric=(
- multilingual_quasi_exact_match_metric(Language.CHINESE, "prefix"),
- multilingual_quasi_f1_score_metric(Language.CHINESE),
+ multilingual_quasi_exact_match_metric(
+ Language.CHINESE, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(Language.CHINESE, normalize_pred=get_cot_answer_normalization(cot)),
),
- stop_sequence=("\n",),
+ stop_sequence=get_stop_sequence(Language.CHINESE, cot),
)
+ for cot in [False, True]
]
# IndicQA: A reading comprehension dataset for 11 Indian languages.
# https://arxiv.org/abs/2407.13522
indicqa_tasks = [
LightevalTaskConfig(
- name=f"indicqa_{language.value}",
+ name=get_generative_task_name("indicqa", language, cot),
prompt_function=get_qa_prompt_function(
language,
lambda line: {
@@ -1155,6 +1114,7 @@
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="ai4bharat/IndicQA",
@@ -1166,12 +1126,14 @@
trust_dataset=True,
evaluation_splits=("test",),
hf_avail_splits=("test",),
- generation_size=400,
+ generation_size=get_cot_generaion_size(cot, 400),
metric=(
- multilingual_quasi_exact_match_metric(language, "prefix"),
- multilingual_quasi_f1_score_metric(language),
+ multilingual_quasi_exact_match_metric(
+ language, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(language, normalize_pred=get_cot_answer_normalization(cot)),
),
- stop_sequence=("\n",),
+ stop_sequence=get_stop_sequence(language, cot),
)
for language in [
Language.ASSAMESE,
@@ -1186,13 +1148,14 @@
Language.TAMIL,
Language.TELUGU,
]
+ for cot in [False, True]
]
# FQuAD v2: French Question Answering Dataset version 2.
# https://arxiv.org/abs/2002.06071
fquad_v2_tasks = [
LightevalTaskConfig(
- name=f"fquadv2_{Language.FRENCH.value}",
+ name=get_generative_task_name("fquadv2", Language.FRENCH, cot),
prompt_function=get_qa_prompt_function(
Language.FRENCH,
lambda line: {
@@ -1200,25 +1163,29 @@
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="manu/fquad2_test",
hf_subset="default",
evaluation_splits=("test_hasAns",),
few_shots_split="valid_hasAns",
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
metric=(
- multilingual_quasi_exact_match_metric(Language.FRENCH, "prefix"),
- multilingual_quasi_f1_score_metric(Language.FRENCH),
+ multilingual_quasi_exact_match_metric(
+ Language.FRENCH, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(Language.FRENCH, normalize_pred=get_cot_answer_normalization(cot)),
),
+ stop_sequence=get_stop_sequence(Language.FRENCH, cot),
)
+ for cot in [False, True]
]
# TQuAD v2: Turkish Question Answering Dataset version 2.
tquad_v2_tasks = [
LightevalTaskConfig(
- name=f"tquadv2_{Language.TURKISH.value}",
+ name=get_generative_task_name("tquadv2", Language.TURKISH, cot),
prompt_function=get_qa_prompt_function(
Language.TURKISH,
lambda line: {
@@ -1226,19 +1193,23 @@
"context": line["context"],
"choices": [a["text"] for a in line["answers"]],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="erdometo/tquad2",
hf_subset="default",
evaluation_splits=("validation",),
few_shots_split="train",
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
metric=(
- multilingual_quasi_exact_match_metric(Language.TURKISH, "prefix"),
- multilingual_quasi_f1_score_metric(Language.TURKISH),
+ multilingual_quasi_exact_match_metric(
+ Language.TURKISH, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(Language.TURKISH, normalize_pred=get_cot_answer_normalization(cot)),
),
+ stop_sequence=get_stop_sequence(Language.TURKISH, cot),
)
+ for cot in [False, True]
]
# Other QA tasks for RC
@@ -1247,7 +1218,7 @@
# https://arxiv.org/abs/2003.05002
tydiqa_tasks = [
LightevalTaskConfig(
- name=f"tydiqa_{language.value}",
+ name=get_generative_task_name("tydiqa", language, cot),
prompt_function=get_qa_prompt_function(
language,
lambda line: {
@@ -1255,18 +1226,21 @@
"context": line["context"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="google-research-datasets/tydiqa",
hf_subset="secondary_task",
evaluation_splits=("validation",),
few_shots_split="train",
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
metric=(
- multilingual_quasi_exact_match_metric(language, "prefix"),
- multilingual_quasi_f1_score_metric(language),
+ multilingual_quasi_exact_match_metric(
+ language, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(language, normalize_pred=get_cot_answer_normalization(cot)),
),
+ stop_sequence=get_stop_sequence(language, cot),
)
for language in [
Language.ENGLISH,
@@ -1281,6 +1255,7 @@
Language.TELUGU,
Language.THAI,
]
+ for cot in [False, True]
]
# C3: A Chinese Challenge Corpus for Cross-lingual and Cross-modal Tasks
@@ -1288,7 +1263,7 @@
# Paper: https://arxiv.org/abs/2004.05986
c3_tasks = [
LightevalTaskConfig(
- name=f"c3_{Language.CHINESE.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("c3", Language.CHINESE, formulation),
suite=("lighteval",),
prompt_function=get_mcq_prompt_function(
Language.CHINESE,
@@ -1304,19 +1279,17 @@
hf_subset="c3",
evaluation_splits=("validation",),
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.CHINESE,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# Other MCF tasks for RC
@@ -1326,7 +1299,7 @@
# Paper: https://aclanthology.org/2023.arabicnlp-1.21/
race_ar_task = [
LightevalTaskConfig(
- name=f"alghafa_race_{Language.ARABIC.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("alghafa_race", Language.ARABIC, formulation),
prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation),
suite=["lighteval"],
hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Translated",
@@ -1338,44 +1311,40 @@
evaluation_splits=["test"],
few_shots_split="validation",
trust_dataset=True,
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.ARABIC,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# SOQAL: A large-scale Arabic reading comprehension dataset.
# https://arxiv.org/abs/1906.05394
soqal_tasks = [
LightevalTaskConfig(
- name=f"soqal_{Language.ARABIC.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("soqal", Language.ARABIC, formulation),
hf_subset="multiple_choice_grounded_statement_soqal_task",
prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation),
evaluation_splits=["test"],
few_shots_split="validation",
suite=["lighteval"],
hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Native",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.ARABIC,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# MLQA (MultiLingual Question Answering) is a benchmark dataset for evaluating cross-lingual question answering performance.
@@ -1384,7 +1353,7 @@
# Paper: https://arxiv.org/abs/1910.07475
mlqa_tasks = [
LightevalTaskConfig(
- name=f"mlqa_{lang.value}",
+ name=get_generative_task_name("mlqa", lang, cot),
prompt_function=get_qa_prompt_function(
lang,
lambda line: {
@@ -1392,6 +1361,7 @@
"question": line["question"],
"choices": [ans for ans in line["answers"]["text"] if len(ans) > 0],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="facebook/mlqa",
@@ -1400,12 +1370,12 @@
trust_dataset=True,
evaluation_splits=("test",),
hf_avail_splits=["test"],
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
metric=[
- multilingual_quasi_exact_match_metric(lang, "prefix"),
- multilingual_quasi_f1_score_metric(lang),
+ multilingual_quasi_exact_match_metric(lang, "prefix", normalize_pred=get_cot_answer_normalization(cot)),
+ multilingual_quasi_f1_score_metric(lang, normalize_pred=get_cot_answer_normalization(cot)),
],
+ stop_sequence=get_stop_sequence(lang, cot),
)
for lang in [
Language.ARABIC,
@@ -1415,13 +1385,18 @@
Language.HINDI,
Language.VIETNAMESE,
]
+ for cot in [False, True]
]
# Belebele: A large-scale reading comprehension dataset covering 122 languages.
# https://arxiv.org/abs/2308.16884
belebele_tasks = [
LightevalTaskConfig(
- name=f"belebele_{language}_{formulation.name.lower()}",
+        name=get_mcf_task_name("belebele", language, formulation),
prompt_function=get_mcq_prompt_function(
iso_639_3_ind_to_iso_639_3_macro[LangCodeLanguage.get(language).to_alpha3()],
lambda line: {
@@ -1437,15 +1412,20 @@
hf_subset=language,
evaluation_splits=("test",),
hf_avail_splits=["test"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+            iso_639_3_ind_to_iso_639_3_macro[LangCodeLanguage.get(language).to_alpha3()],
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(
+ iso_639_3_ind_to_iso_639_3_macro[LangCodeLanguage.get(language).to_alpha3()],
+ formulation.cot,
+ ),
)
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for formulation in DEFAULT_FORMULATIONS
for language in [
"acm_Arab",
"arz_Arab",
@@ -1589,10 +1569,6 @@
*race_ar_task,
*belebele_tasks,
*c3_tasks,
- *squad_it_tasks,
- *squad_es_tasks,
- *faquad_tasks,
- *germanquad_tasks,
]
)
@@ -1670,7 +1646,7 @@
# Paper: https://arxiv.org/abs/2407.21783
meta_mmlu_tasks = [
LightevalTaskConfig(
- name=f"meta_mmlu_{language.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('meta_mmlu', language, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
language,
lambda line: {
@@ -1691,14 +1667,16 @@
),
evaluation_splits=("latest",),
hf_avail_splits=["latest"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for subset in MMLU_SUBSETS
for language in [
@@ -1710,18 +1688,14 @@
Language.PORTUGUESE,
Language.THAI,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# MLMM MMLU: Another multilingual version of MMLU
# Paper: https://github.com/nlp-uoregon/mlmm-evaluation
mlmm_mmlu_tasks = [
LightevalTaskConfig(
- name=f"mlmm_mmlu_{language.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('mlmm_mmlu', language, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
language,
lambda line: {
@@ -1739,14 +1713,16 @@
trust_dataset=True,
evaluation_splits=("test",),
few_shots_split="dev",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for subset in MMLU_SUBSETS
for language in [
@@ -1777,16 +1753,12 @@
Language.TELUGU,
Language.KANNADA,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
openai_mmlu_tasks = [
LightevalTaskConfig(
- name=f"openai_mmlu_{language[0].value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('openai_mmlu', language[0], formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
language[0],
lambda line: {
@@ -1803,14 +1775,16 @@
hf_avail_splits=["test"],
hf_filter=partial(lambda subset, x: x["Subject"].lower() == subset, subset),
hf_revision="038c7808122969ead7456361af05cb8f47d247f8",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language[0],
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(language[0], formulation.cot),
)
for subset in MMLU_SUBSETS
for language in [
@@ -1829,11 +1803,7 @@
(Language.YORUBA, "YO_NG"),
(Language.CHINESE, "ZH_CN"),
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# Translated MMLU using both professional and non-professional translators. Contains tags for cultural sensitivity.
@@ -1863,13 +1833,13 @@
lambda subset, sensitivity_label, x: x["subject"].lower() == subset
and (
sensitivity_label == "ALL" or sensitivity_label in x["cultural_sensitivity_label"].replace("-", "UNK")
- )
- and all(x[f"option_{opt}"] is not None and x[f"option_{opt}"].strip() for opt in "abcd"),
+ ),
subset,
sensitivity_label,
),
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
@@ -1914,11 +1884,7 @@
Language.YORUBA,
Language.ZULU,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
for sensitivity_label in ["ALL", "CA", "CS", "UNK"]
]
@@ -1936,7 +1902,7 @@
# From https://arxiv.org/abs/2406.03368. Human translated MMLU.
afri_mmlu_tasks = [
LightevalTaskConfig(
- name=f"afri_mmlu_{language.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('afri_mmlu', language, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
language,
lambda line: {
@@ -1954,14 +1920,16 @@
hf_filter=partial(lambda subset, line: line["subject"] == subset, subset),
evaluation_splits=("test",),
few_shots_split="dev",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for subset in AFRI_MMLU_SUBSETS
for language in [
@@ -1983,18 +1951,14 @@
Language.YORUBA,
# Language.ZULU,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# RUMMLU: Russian Massive Multitask Language Understanding
# Paper: https://arxiv.org/html/2401.04531v2
rummlu = [
LightevalTaskConfig(
- name=f"rummlu_{Language.RUSSIAN.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('rummlu', Language.RUSSIAN, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
Language.RUSSIAN,
lambda line: {
@@ -2010,28 +1974,26 @@
hf_filter=lambda x: x["meta"]["domain"] == subset,
evaluation_splits=("public_test",),
hf_avail_splits=["public_test"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.RUSSIAN,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot),
)
for subset in MMLU_SUBSETS
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# MMLU Turkish: Turkish version of MMLU
# Translated using openai GPT
mmlu_turkish = [
LightevalTaskConfig(
- name=f"community_mmlu_{Language.TURKISH.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('community_mmlu', Language.TURKISH, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
Language.TURKISH,
lambda line: {"question": line["question"], "choices": line["choices"], "gold_idx": int(line["answer"])},
@@ -2042,21 +2004,19 @@
hf_subset=subset,
evaluation_splits=("test",),
few_shots_split="dev",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.TURKISH,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot),
)
for subset in MMLU_SUBSETS
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# CMMLU: Chinese Massive Multitask Language Understanding
@@ -2134,7 +2094,7 @@
cmmlu_tasks = [
LightevalTaskConfig(
- name=f"cmmlu_{Language.CHINESE.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('cmmlu', Language.CHINESE, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
Language.CHINESE,
lambda line: {
@@ -2149,21 +2109,19 @@
hf_subset=subset,
evaluation_splits=("test",),
few_shots_split="dev",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.CHINESE,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot),
)
for subset in CMMLU_SUBSETS
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# Arabic MMLU: Arabic version of MMLU
@@ -2214,7 +2172,7 @@
arabic_mmlu_tasks = [
LightevalTaskConfig(
- name=f"mmlu_{Language.ARABIC.value}_{formulation.name.lower()}:{normalize_subset(subset)}",
+ name=f"{get_mcf_task_name('mmlu', Language.ARABIC, formulation)}:{normalize_subset(subset)}",
prompt_function=get_mcq_prompt_function(
Language.ARABIC,
lambda line: {
@@ -2230,21 +2188,19 @@
hf_subset=subset,
evaluation_splits=("test",),
hf_avail_splits=["dev"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.ARABIC,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot),
)
for subset in ARABIC_MMLU_SUBSETS
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
@@ -2262,7 +2218,7 @@
turkish_mmlu_tasks = [
LightevalTaskConfig(
- name=f"mmlu_{Language.TURKISH.value}_{formulation.name.lower()}:{normalize_subset(subset)}",
+ name=f"{get_mcf_task_name('mmlu', Language.TURKISH, formulation)}:{normalize_subset(subset)}",
prompt_function=get_mcq_prompt_function(
Language.TURKISH,
lambda line: {
@@ -2277,21 +2233,19 @@
hf_subset=subset,
evaluation_splits=("test",),
few_shots_split="dev",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.TURKISH,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot),
)
for subset in TURKISH_MMLU_SUBSET
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -2323,7 +2277,7 @@
# github: https://github.com/nlp-uoregon/mlmm-evaluation
mlmm_arc_challenge_tasks = [
LightevalTaskConfig(
- name=f"mlmm_arc_{language.value}_{formulation.name.lower()}:challenge",
+ name=f"{get_mcf_task_name('mlmm_arc', language, formulation)}:challenge",
prompt_function=get_mcq_prompt_function(
language,
lambda line: {
@@ -2342,14 +2296,16 @@
trust_dataset=True,
evaluation_splits=("test",),
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.RUSSIAN,
@@ -2379,11 +2335,7 @@
Language.TELUGU,
Language.KANNADA,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# Arabic ARC Easy
@@ -2392,7 +2344,7 @@
# Paper: https://aclanthology.org/2023.arabicnlp-1.21/
arabic_ledarboard_arc_easy = [
LightevalTaskConfig(
- name=f"alghafa_arc_{Language.ARABIC.value}_{formulation.name.lower()}:easy",
+ name=f"{get_mcf_task_name('alghafa_arc', Language.ARABIC, formulation)}:easy",
prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation),
suite=["lighteval"],
hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Translated",
@@ -2401,25 +2353,23 @@
trust_dataset=True,
evaluation_splits=["test"],
few_shots_split="validation",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.ARABIC,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
lumi_arc = [
LightevalTaskConfig(
- name=f"lumi_arc_{language.value}_{formulation.name.lower()}:challenge",
+ name=f"{get_mcf_task_name('lumi_arc', language, formulation)}:challenge",
prompt_function=get_mcq_prompt_function(
language,
lambda line: {
@@ -2436,20 +2386,18 @@
hf_subset=standardize_tag(language.value),
evaluation_splits=["test"],
few_shots_split="validation",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
for language in [
Language.DANISH,
Language.GERMAN,
@@ -2469,7 +2417,7 @@
# Comes from the Turkish leaderboard
turkish_arc_tasks = [
LightevalTaskConfig(
- name=f"community_arc_{Language.TURKISH.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('community_arc', Language.TURKISH, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
Language.TURKISH,
lambda line: {
@@ -2486,26 +2434,24 @@
hf_subset=f"ARC-{subset.capitalize()}",
evaluation_splits=("test",),
hf_avail_splits=["train"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.TURKISH,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
]
+ ([loglikelihood_acc_metric(normalization=LogProbPMINorm())] if subset == "challenge" else []), # type: ignore
),
+ stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot),
)
for subset in ["easy", "challenge"]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
hindi_arc_tasks = [
LightevalTaskConfig(
- name=f"community_arc_{Language.HINDI.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('community_arc', Language.HINDI, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
Language.HINDI,
lambda line: {
@@ -2522,26 +2468,24 @@
hf_subset=f"ARC-{subset.capitalize()}",
evaluation_splits=("test",),
few_shots_split="validation",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.HINDI,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
]
+ ([loglikelihood_acc_metric(normalization=LogProbPMINorm())] if subset == "challenge" else []), # type: ignore
),
+ stop_sequence=get_stop_sequence(Language.HINDI, formulation.cot),
)
for subset in ["easy", "challenge"]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
arabic_arc_tasks = [
LightevalTaskConfig(
- name=f"alghafa_arc_{Language.ARABIC.value}_{formulation.name.lower()}:easy",
+ name=f"{get_mcf_task_name('alghafa_arc', Language.ARABIC, formulation)}:easy",
prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation),
suite=["lighteval"],
hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Translated",
@@ -2550,25 +2494,23 @@
evaluation_splits=["test"],
few_shots_split="validation",
few_shots_select="sequential",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.ARABIC,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
trust_dataset=True,
+ stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
swahili_arc_tasks = [
LightevalTaskConfig(
- name=f"community_arc_{Language.SWAHILI.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('community_arc', Language.SWAHILI, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
Language.SWAHILI,
lambda line: {
@@ -2588,21 +2530,19 @@
else "dc1df9df632d14c251594d9129fb833d2ca4429c",
evaluation_splits=("test",),
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.SWAHILI,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
]
+ ([loglikelihood_acc_metric(normalization=LogProbPMINorm())] if subset == "challenge" else []), # type: ignore
),
+ stop_sequence=get_stop_sequence(Language.SWAHILI, formulation.cot),
)
for subset in ["easy", "challenge"]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
@@ -2628,7 +2568,7 @@
# github: https://github.com/nlp-uoregon/mlmm-evaluation
mlmm_truthfulqa_tasks = [
LightevalTaskConfig(
- name=f"mlmm_truthfulqa_{language.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('mlmm_truthfulqa', language, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
language,
partial(
@@ -2648,13 +2588,15 @@
trust_dataset=True,
evaluation_splits=("validation",),
hf_avail_splits=["validation"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for subset in ["mc1", "mc2"]
for language in [
@@ -2692,18 +2634,14 @@
Language.VIETNAMESE,
Language.CHINESE,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# Turkish TruthfulQA
# Based on turkish leaderboard
turkish_truthfulqa = [
LightevalTaskConfig(
- name=f"community_truthfulqa_{Language.TURKISH.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('community_truthfulqa', Language.TURKISH, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
Language.TURKISH,
partial(
@@ -2721,20 +2659,18 @@
hf_subset="default",
evaluation_splits=("validation",),
hf_avail_splits=["validation"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.TURKISH,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot),
)
for subset in ["mc1", "mc2"]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -2861,7 +2797,7 @@
exams_tasks = [
LightevalTaskConfig(
- name=f"exams_{language.value}_{formulation.name.lower()}:{normalize_subset(subject)}",
+ name=f"{get_mcf_task_name('exams', language, formulation)}:{normalize_subset(subject)}",
prompt_function=get_mcq_prompt_function(
language,
lambda line: {
@@ -2884,21 +2820,19 @@
),
evaluation_splits=("test",),
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in exams_subjects_by_lang.keys()
for subject in exams_subjects_by_lang[language]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# M3Exam: Multitask Multilingual Multimodal Evaluation Benchmark
@@ -2906,7 +2840,7 @@
# Paper: https://arxiv.org/abs/2306.05179
m3exams_tasks = [
LightevalTaskConfig(
- name=f"m3exams_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("m3exams", language, formulation),
suite=("lighteval",),
prompt_function=get_mcq_prompt_function(
language,
@@ -2917,14 +2851,15 @@
hf_subset=LangCodeLanguage(standardize_tag(language.value)).language_name().lower(),
evaluation_splits=("test",),
few_shots_split="dev",
- generation_size=-1,
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.AFRIKAANS,
@@ -2937,11 +2872,7 @@
Language.THAI,
Language.VIETNAMESE,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# Thai Exams
@@ -2953,27 +2884,29 @@
thai_exams_tasks = [
LightevalTaskConfig(
- name=f"thai_exams_{Language.THAI.value}_{formulation.name.lower()}:{subset}",
- prompt_function=get_mcq_prompt_function(Language.THAI, thai_exams_adapter, formulation=formulation),
+ name=f"{get_mcf_task_name('thai_exams', Language.THAI, formulation)}:{subset}",
+ prompt_function=get_mcq_prompt_function(
+ Language.THAI,
+ thai_exams_adapter,
+ formulation=formulation,
+ ),
suite=("lighteval",),
hf_repo="scb10x/thai_exam",
hf_subset=subset,
evaluation_splits=("test",),
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.THAI,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.THAI, formulation.cot),
)
for subset in THAI_EXAMS_SUBSETS
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -2992,7 +2925,7 @@
# Paper: https://arxiv.org/abs/2110.08462
xcsqa_tasks = [
LightevalTaskConfig(
- name=f"xcsqa_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("xcsqa", language, formulation),
prompt_function=get_mcq_prompt_function(
language,
lambda line: {
@@ -3004,20 +2937,22 @@
),
suite=("lighteval",),
hf_repo="INK-USC/xcsr",
- hf_subset=f"X-CSQA-{standardize_tag(language.value) if language != Language.JAPANESE else 'jap'}",
+ hf_subset=f"X-CSQA-{standardize_tag(language.value)}",
hf_filter=lambda x: all(
len(x["question"]["choices"]["text"][i].strip()) > 0 for i in range(len(x["question"]["choices"]["text"]))
),
evaluation_splits=("validation",),
hf_avail_splits=["validation"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.ARABIC,
@@ -3037,11 +2972,7 @@
Language.VIETNAMESE,
Language.CHINESE,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -3059,29 +2990,30 @@
# Arabic version: https://aclanthology.org/2023.arabicnlp-1.21/
piqa_ar_tasks = [
LightevalTaskConfig(
- name=f"alghafa_piqa_{Language.ARABIC.value}_{formulation.name.lower()}",
- prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation),
+ name=get_mcf_task_name("alghafa_piqa", Language.ARABIC, formulation),
+ prompt_function=get_mcq_prompt_function(
+ Language.ARABIC,
+ alghafa_adapter,
+ formulation=formulation,
+ ),
suite=["lighteval"],
hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Translated",
hf_revision="08663706ee7cab30c4b7dc1bb00042a3227ce1ff",
hf_subset="piqa_ar",
hf_avail_splits=["test", "validation"],
evaluation_splits=["test"],
- few_shots_split="validation",
trust_dataset=True,
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.ARABIC,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -3099,8 +3031,12 @@
# Arabic version: https://aclanthology.org/2023.arabicnlp-1.21/
openbook_ara_tasks = [
LightevalTaskConfig(
- name=f"alghafa_openbookqa_{Language.ARABIC.value}_{formulation.name.lower()}",
- prompt_function=get_mcq_prompt_function(Language.ARABIC, alghafa_adapter, formulation=formulation),
+ name=get_mcf_task_name("alghafa_openbookqa", Language.ARABIC, formulation),
+ prompt_function=get_mcq_prompt_function(
+ Language.ARABIC,
+ alghafa_adapter,
+ formulation=formulation,
+ ),
suite=["lighteval"],
hf_repo="OALL/AlGhafa-Arabic-LLM-Benchmark-Translated",
hf_subset="openbook_qa_ext_ar",
@@ -3108,67 +3044,30 @@
trust_dataset=True,
evaluation_splits=["test"],
few_shots_split="validation",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.ARABIC,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
-# Spanish version of OpenBookQA from BSC Language Technology group
-# Dataset: https://huggingface.co/datasets/BSC-LT/openbookqa-es
-openbook_es_tasks = [
+# The Russian version is part of the MERA (Multilingual Enhanced Russian NLP Architectures) project.
+# Paper: https://arxiv.org/abs/2401.04531
+openbook_rus_tasks = [
LightevalTaskConfig(
- name=f"openbookqa_{Language.SPANISH.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("mera_openbookqa", Language.RUSSIAN, formulation),
prompt_function=get_mcq_prompt_function(
- Language.SPANISH,
+ Language.RUSSIAN,
lambda line: {
- "question": line["question_stem"],
- "choices": line["choices"]["text"],
- "gold_idx": LETTER_INDICES.index(line["answerKey"]),
- },
- formulation=formulation,
- ),
- suite=["lighteval"],
- hf_repo="BSC-LT/openbookqa-es",
- hf_subset="default",
- evaluation_splits=("test",),
- few_shots_split="validation",
- metric=get_metrics_for_formulation(
- formulation,
- [
- loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
- loglikelihood_acc_metric(normalization=LogProbCharNorm()),
- ],
- ),
- )
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
-]
-
-
-# The Russian version is part of the MERA (Multilingual Enhanced Russian NLP Architectures) project.
-# Paper: https://arxiv.org/abs/2401.04531
-openbook_rus_tasks = [
- LightevalTaskConfig(
- name=f"mera_openbookqa_{Language.RUSSIAN.value}_{formulation.name.lower()}",
- prompt_function=get_mcq_prompt_function(
- Language.RUSSIAN,
- lambda line: {
- "question": line["inputs"]["question"],
- "choices": [line["inputs"][f"option_{i.lower()}"] for i in LETTER_INDICES[:4]],
- "gold_idx": LETTER_INDICES.index(line["outputs"]),
+ "question": line["inputs"]["question"],
+ "choices": [line["inputs"][f"option_{i.lower()}"] for i in LETTER_INDICES[:4]],
+ "gold_idx": LETTER_INDICES.index(line["outputs"]),
},
formulation=formulation,
),
@@ -3177,26 +3076,23 @@
hf_subset="ruopenbookqa",
evaluation_splits=("train",),
hf_avail_splits=["train"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.RUSSIAN,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
[
*openbook_rus_tasks,
*openbook_ara_tasks,
- *openbook_es_tasks,
]
)
@@ -3209,7 +3105,7 @@
# Paper: https://aclanthology.org/2023.arabicnlp-1.21/
sciqa_ar_task = [
LightevalTaskConfig(
- name=f"alghafa_sciqa_{Language.ARABIC.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("alghafa_sciqa", Language.ARABIC, formulation),
prompt_function=get_mcq_prompt_function(
Language.ARABIC,
sciqa_adapter,
@@ -3223,20 +3119,18 @@
evaluation_splits=["test"],
few_shots_split="validation",
few_shots_select="sequential",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.ARABIC,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
trust_dataset=True,
+ stop_sequence=get_stop_sequence(Language.ARABIC, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -3253,7 +3147,7 @@
# MERA: https://github.com/ai-forever/MERA
mathlogicqa_rus_tasks = [
LightevalTaskConfig(
- name=f"mathlogic_qa_{Language.RUSSIAN.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("mathlogic_qa", Language.RUSSIAN, formulation),
prompt_function=get_mcq_prompt_function(
Language.RUSSIAN,
lambda line: {
@@ -3268,69 +3162,64 @@
hf_subset="mathlogicqa",
evaluation_splits=("train",),
hf_avail_splits=["train"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.RUSSIAN,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot),
)
- for formulation in [
- CFFormulation(),
- MCFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
cmath_tasks = [
LightevalTaskConfig(
- name=f"cmath_{Language.CHINESE.value}",
- prompt_function=get_qa_prompt_function(
+ name=get_generative_task_name("cmath", Language.CHINESE, cot),
+ prompt_function=get_math_qa_prompt_function(
Language.CHINESE,
lambda line: {
"question": line["question"],
"choices": [line["golden"]],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="weitianwen/cmath",
hf_subset="default",
evaluation_splits=("test",),
few_shots_split="validation",
- generation_size=25,
+ generation_size=get_cot_generaion_size(cot, 100),
metric=[
- multilingual_quasi_exact_match_metric(Language.CHINESE, "full"),
+ multilingual_extractive_match_metric(Language.CHINESE),
],
- stop_sequence=("\n",),
+ stop_sequence=get_stop_sequence(Language.CHINESE, cot),
)
+ for cot in (False, True)
]
mgsm_tasks = [
LightevalTaskConfig(
- name=f"mgsm_{language.value}",
- prompt_function=get_qa_prompt_function(
+ name=get_generative_task_name("mgsm", language, cot),
+ prompt_function=get_math_qa_prompt_function(
language,
- lambda line: {
- "question": line["question"],
- # The cot is available but we have no use:
- # line["answer"]
- "choices": [str(line["answer_number"])],
- },
+ mgsm_adapter,
+ cot=cot,
),
suite=("lighteval",),
hf_repo="juletxara/mgsm",
hf_subset=standardize_tag(language.value),
evaluation_splits=("test",),
few_shots_split="train",
- generation_size=25,
+ generation_size=get_cot_generaion_size(cot, 100),
metric=[
- multilingual_quasi_exact_match_metric(language, "full"),
+ multilingual_extractive_match_metric(language),
],
- stop_sequence=("\n",),
+ stop_sequence=get_stop_sequence(language, cot),
)
for language in [
- Language.ENGLISH,
Language.SPANISH,
Language.FRENCH,
Language.GERMAN,
@@ -3342,31 +3231,29 @@
Language.BENGALI,
Language.TELUGU,
]
+ for cot in (False, True)
]
+
# African MGSM: MGSM for African Languages
# From https://arxiv.org/abs/2406.03368. Human translated MGSM.
afri_mgsm_tasks = [
LightevalTaskConfig(
- name=f"afri_mgsm_{language.value}",
- prompt_function=get_qa_prompt_function(
+ name=get_generative_task_name("afri_mgsm", language, cot),
+ prompt_function=get_math_qa_prompt_function(
language,
- lambda line: {
- "question": line["question"],
- # The cot is available but we have no use:
- # line["answer"]
- "choices": [str(line["answer_number"])],
- },
+ mgsm_adapter,
+ cot=cot,
),
suite=("lighteval",),
hf_repo="masakhane/afrimgsm",
hf_subset=language.value,
evaluation_splits=("test",),
few_shots_split="train",
- generation_size=25,
+ generation_size=get_cot_generaion_size(cot, 100),
metric=[
- multilingual_quasi_exact_match_metric(language, "full"),
+ multilingual_extractive_match_metric(language),
],
- stop_sequence=("\n",),
+ stop_sequence=get_stop_sequence(language, cot),
)
for language in [
Language.AMHARIC,
@@ -3387,13 +3274,353 @@
Language.YORUBA,
# Language.ZULU,
]
+ for cot in (False, True)
+]
+
+# MSVAMP - Math Word Problems (Translated from SVAMP using Google Translate)
+msvamp_tasks = [
+ LightevalTaskConfig(
+ name=get_generative_task_name("msvamp", language, cot),
+ prompt_function=get_math_qa_prompt_function(
+ language,
+ lambda line: {
+ "question": line["m_query"], # Using the translated version of the question
+ "choices": [float_to_choice_string(line["response"])], # The answer as a string
+ },
+ cot=cot,
+ ),
+ suite=("lighteval",),
+ hf_repo="Mathoctopus/MSVAMP",
+ hf_subset=standardize_tag(language.value),
+ evaluation_splits=("test",),
+ # Don't use balanced few-shot selection here: the biggest answer clusters are 1, 2, 3, 4, 5,
+ # which results in some LLMs simply outputting 6.
+ few_shots_select="random",
+ hf_avail_splits=["test"],
+ generation_size=get_cot_generaion_size(cot, 100),
+ metric=[
+ multilingual_extractive_match_metric(language),
+ ],
+ stop_sequence=get_stop_sequence(language, cot),
+ )
+ for language in [
+ Language.BENGALI,
+ Language.GERMAN,
+ Language.ENGLISH,
+ Language.SPANISH,
+ Language.FRENCH,
+ Language.JAPANESE,
+ Language.RUSSIAN,
+ Language.SWAHILI,
+ Language.THAI,
+ Language.CHINESE,
+ ]
+ for cot in (False, True)
+]
+
+
+# CMM-Math - Chinese Multimodal Math Dataset
+# CMM-Math is a comprehensive Chinese mathematical reasoning dataset containing over 28,000 high-quality samples
+# across 12 grade levels from elementary to high school. It includes multiple-choice and fill-in-the-blank questions
+# with detailed solutions. The dataset features both text-only and multimodal problems (with visual context in
+# questions/options). It consists of 22k+ training samples and 5k+ evaluation samples, designed to evaluate and
+# enhance mathematical reasoning capabilities of large language and multimodal models.
+# Note: Only the MCQ subset is implemented
+cmm_math_mc_tasks = [
+ LightevalTaskConfig(
+ name=get_mcf_task_name("cmm_math", Language.CHINESE, formulation),
+ prompt_function=get_mcq_prompt_function(
+ Language.CHINESE,
+ cmm_math_adapter,
+ formulation=formulation,
+ ),
+ suite=("lighteval",),
+ hf_repo="ecnu-icalk/cmm-math",
+ hf_subset="default",
+ hf_filter=lambda x: x["image"] == "[]"
+ and x["options"] != "[]", # keep only text-only MCQ examples; both "image" and "options" are stringified lists in this dataset
+ evaluation_splits=("test",),
+ few_shots_split="train",
+ metric=get_metrics_for_mcq_formulation(
+ formulation,
+ Language.CHINESE,
+ [
+ loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
+ loglikelihood_acc_metric(normalization=LogProbCharNorm()),
+ ],
+ ),
+ stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot),
+ )
+ for formulation in DEFAULT_FORMULATIONS
+]
+
+
+# Math23K - Chinese Math Word Problem Dataset
+# Math23K is a dataset containing 23,162 Chinese math word problems crawled from the internet.
+# Originally introduced in "Deep Neural Solver for Math Word Problems", it consists of math word
+# problems in Chinese text along with their corresponding numerical answers. The dataset is split
+# into training and test sets and is commonly used to evaluate mathematical reasoning capabilities
+# of language models on Chinese text. Each example contains the original Chinese problem text and
+# its corresponding numerical solution.
+math23k_tasks = [
+ LightevalTaskConfig(
+ name=get_generative_task_name("math23k", Language.CHINESE, cot),
+ prompt_function=get_math_qa_prompt_function(
+ Language.CHINESE,
+ lambda line: {
+ "question": line["original_text"], # Use the original Chinese text
+ "choices": [str(line["ans"])], # the "answer" field holds the computed number, while "ans" is symbolic
+ },
+ cot=cot,
+ ),
+ suite=("lighteval",),
+ hf_repo="Gxg/Math23K",
+ hf_subset="default",
+ evaluation_splits=("test",),
+ few_shots_split="train",
+ generation_size=get_cot_generaion_size(cot, 100), # Similar to other math tasks like msvamp
+ metric=[multilingual_extractive_match_metric(Language.CHINESE)],
+ stop_sequence=get_stop_sequence(Language.CHINESE, cot),
+ )
+ for cot in (False, True)
+]
+
+# TAL-SCQ5K - Chinese Math Word Problem Dataset
+# TAL-SCQ5K is a high-quality mathematical competition dataset created by TAL Education Group,
+# consisting of 5K multiple-choice questions (3K training, 2K testing) covering math topics from
+# primary through high school levels. The dataset includes detailed solution steps and standardized
+# LaTeX expressions.
+
+tal_scq5k_tasks = [
+ LightevalTaskConfig(
+ name=get_mcf_task_name("tal_scq5k", Language.CHINESE, formulation),
+ prompt_function=get_mcq_prompt_function(
+ Language.CHINESE,
+ lambda line: {
+ "question": line["problem"].strip(),
+ "choices": [
+ opt[0]["content"] for opt in sorted(line["answer_option_list"], key=lambda x: x[0]["aoVal"])
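+ # each answer_option_list entry appears to be a one-element list of {"aoVal": "A", "content": ...} (assumed schema); sorting by aoVal keeps the options in A-D order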
+ ],
+ "gold_idx": LETTER_INDICES.index(line["answer_value"]),
+ },
+ formulation=formulation,
+ ),
+ suite=("lighteval",),
+ hf_repo="math-eval/TAL-SCQ5K",
+ hf_subset="default",
+ evaluation_splits=("test",),
+ few_shots_split="train",
+ metric=get_metrics_for_mcq_formulation(
+ formulation,
+ Language.CHINESE,
+ [
+ loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
+ loglikelihood_acc_metric(normalization=LogProbCharNorm()),
+ ],
+ ),
+ stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot),
+ )
+ for formulation in DEFAULT_FORMULATIONS
+]
+
+# MathQA-TR - Turkish Math Question Answering Dataset
+# Sourced from https://github.com/esingedik/Turkish-MWP-Corpora-and-Code
+# Translated using Google Translate
+# The math word problems are collected from MAWPS, ASDiv-A, and SVAMP.
+mathqa_tr_tasks = [
+ LightevalTaskConfig(
+ name=get_generative_task_name("mathqa", Language.TURKISH, cot),
+ prompt_function=get_math_qa_prompt_function(
+ Language.TURKISH,
+ lambda line: {
+ "question": line["question"],
+ "choices": [line["answer"]],
+ },
+ cot=cot,
+ ),
+ suite=("lighteval",),
+ hf_repo="lighteval/MathQA-TR",
+ hf_subset="default",
+ evaluation_splits=("test",),
+ few_shots_split="train",
+ generation_size=get_cot_generaion_size(cot, 100), # Similar to other math tasks
+ metric=[
+ multilingual_extractive_match_metric(Language.TURKISH),
+ ],
+ stop_sequence=get_stop_sequence(Language.TURKISH, cot),
+ )
+ for cot in (False, True)
+]
+
+mwp_tr_tasks = [
+ LightevalTaskConfig(
+ name=get_generative_task_name("mwp", Language.TURKISH, cot),
+ prompt_function=get_math_qa_prompt_function(
+ Language.TURKISH,
+ lambda line: {
+ "question": line["question"],
+ "choices": [line["answer"]],
+ },
+ cot=cot,
+ ),
+ suite=("lighteval",),
+ hf_repo="lighteval/MWP-TR",
+ hf_subset="default",
+ evaluation_splits=("test",),
+ few_shots_split="train",
+ generation_size=get_cot_generaion_size(cot, 100), # Similar to other math tasks
+ metric=[
+ multilingual_extractive_match_metric(Language.TURKISH),
+ ],
+ stop_sequence=get_stop_sequence(Language.TURKISH, cot),
+ )
+ for cot in (False, True)
+]
+
+# MERA Arithmetic Tasks
+# Paper: https://arxiv.org/abs/2401.04531
+mera_arithmetic_tasks = [
+ LightevalTaskConfig(
+ name=f"{get_generative_task_name('mera_arithmetic', Language.RUSSIAN, cot)}:{subset}",
+ prompt_function=get_math_qa_prompt_function(
+ Language.RUSSIAN,
+ lambda line: {
+ "question": line["inputs"],
+ "choices": [str(line["outputs"])],
+ },
+ cot=cot,
+ ),
+ suite=("lighteval",),
+ hf_repo="ai-forever/MERA",
+ hf_subset=subset,
+ evaluation_splits=("public_test",)
+ if subset == "rumodar"
+ else ("train",), # MERA uses train split for evaluation
+ hf_avail_splits=["public_test"] if subset == "rumodar" else ["train"],
+ generation_size=get_cot_generaion_size(cot, 100), # Similar to other math tasks
+ metric=[
+ Metrics.quasi_exact_match_math,
+ ],
+ stop_sequence=get_stop_sequence(Language.RUSSIAN, cot),
+ )
+ for subset in ["rumodar", "rumultiar", "simplear"]
+ for cot in (False, True)
+]
+
+# QazUNTv2 Tasks - High school math problems in English and Russian
+# A bilingual dataset for evaluating LLMs on high school math problems covering:
+# - Algebra (436 problems)
+# - Logic (312 problems)
+# - Probability (163 problems)
+# Each problem includes multiple choice options and detailed solutions.
+# The dataset was manually curated, with English translations via Google Translate.
+# Dataset DOI: https://doi.org/10.17632/52vc6v4czj.1
+
+qazuntv2_tasks = [
+ LightevalTaskConfig(
+ name=f"{get_mcf_task_name('qazuntv2', lang, formulation)}:{subset}",
+ prompt_function=get_mcq_prompt_function(
+ lang,
+ qazuntv2_adapter,
+ formulation=formulation,
+ ),
+ suite=("lighteval",),
+ hf_repo="lighteval/QazUNTv2",
+ hf_subset=standardize_tag(lang.value),
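+ # bind subset eagerly via partial(); a bare lambda in this comprehension would late-bind and leave every task filtering on the last subset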
+ hf_filter=partial(lambda subset, x: x["section"].lower() == subset, subset),
+ evaluation_splits=("train",), # Dataset only has train split
+ hf_avail_splits=["train"],
+ metric=get_metrics_for_mcq_formulation(
+ formulation,
+ lang,
+ [
+ loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
+ loglikelihood_acc_metric(normalization=LogProbCharNorm()),
+ ],
+ ),
+ stop_sequence=get_stop_sequence(lang, formulation.cot),
+ )
+ for lang in [
+ Language.ENGLISH,
+ Language.RUSSIAN,
+ ]
+ for subset in ["algebra", "logic", "probability"]
+ for formulation in DEFAULT_FORMULATIONS
]
+
+
+# ArMATH - Arabic Math Word Problems
+# A dataset of 6,000 primary-school math word problems in Modern Standard Arabic (MSA).
+# Source: https://github.com/reem-codes/ArMATH
+armath_tasks = [
+ LightevalTaskConfig(
+ name=get_generative_task_name("armath", Language.ARABIC, cot),
+ prompt_function=get_math_qa_prompt_function(
+ Language.ARABIC,
+ lambda line: {
+ "question": line["question"],
+ "choices": [float_to_choice_string(line["answer"])], # "answer" holds the evaluated equation result
+ },
+ cot=cot,
+ ),
+ suite=("lighteval",),
+ hf_repo="khalidalt/arMath",
+ hf_subset="default",
+ evaluation_splits=("test",), # Dataset only has test split
+ few_shots_split="validation",
+ generation_size=get_cot_generaion_size(cot, 100), # Similar to other math tasks
+ metric=[
+ multilingual_extractive_match_metric(Language.ARABIC),
+ ],
+ stop_sequence=get_stop_sequence(Language.ARABIC, cot),
+ )
+ for cot in (False, True)
+]
+
+# HAWP - Hindi Arithmetic Word Problems
+# A dataset containing 2.3k arithmetic word problems in Hindi, designed to evaluate
+# mathematical reasoning capabilities of language models. Each problem includes the
+# question text, equation, and numerical answer.
+hawp_tasks = [
+ LightevalTaskConfig(
+ name=get_generative_task_name("hawp", Language.HINDI, cot),
+ prompt_function=get_math_qa_prompt_function(
+ Language.HINDI,
+ lambda line: {
+ "question": line["Problem"],
+ "choices": [str(line["answer"])],
+ },
+ cot=cot,
+ ),
+ suite=("lighteval",),
+ hf_repo="lighteval/HAWP",
+ hf_subset="default",
+ evaluation_splits=("test",),
+ few_shots_split="dev",
+ generation_size=get_cot_generaion_size(cot, 100),
+ metric=[multilingual_extractive_match_metric(Language.HINDI, precision=6)],
+ stop_sequence=get_stop_sequence(Language.HINDI, cot),
+ )
+ for cot in (False, True)
+]
+
TASKS_TABLE.extend(
[
*cmath_tasks,
*mathlogicqa_rus_tasks,
*mgsm_tasks,
*afri_mgsm_tasks,
+ *armath_tasks,
+ *msvamp_tasks,
+ *cmm_math_mc_tasks,
+ *math23k_tasks,
+ *tal_scq5k_tasks,
+ *mathqa_tr_tasks,
+ *mwp_tr_tasks,
+ *mera_arithmetic_tasks,
+ *qazuntv2_tasks,
+ *hawp_tasks,
]
)
@@ -3417,7 +3644,7 @@
agieval_tasks_zh = [
LightevalTaskConfig(
- name=f"agieval_{Language.CHINESE.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('agieval', Language.CHINESE, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
Language.CHINESE,
partial(
@@ -3433,21 +3660,19 @@
evaluation_splits=("test",),
hf_avail_splits=["test"],
few_shots_split=None,
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.CHINESE,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
loglikelihood_acc_metric(normalization=LogProbPMINorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot),
)
for subset in CHINESE_AGIEVAL_SUBSET
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
# C-Eval: Chinese Evaluation suite
# Similar to MMLu but with different categories
@@ -3467,7 +3692,6 @@
"high_school_mathematics",
"high_school_physics",
"high_school_chemistry",
- "high_school_biology",
"middle_school_mathematics",
"middle_school_biology",
"middle_school_physics",
@@ -3509,7 +3733,7 @@
ceval_tasks = [
LightevalTaskConfig(
- name=f"ceval_{Language.CHINESE.value}_{formulation.name.lower()}:{subset}",
+ name=f"{get_mcf_task_name('ceval', Language.CHINESE, formulation)}:{subset}",
prompt_function=get_mcq_prompt_function(
Language.CHINESE,
partial(
@@ -3524,92 +3748,18 @@
hf_subset=subset,
evaluation_splits=("val",),
few_shots_split="dev",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.CHINESE,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.CHINESE, formulation.cot),
)
for subset in CEVAL_SUBSET
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
-]
-
-
-# OAB Exams: A collection of questions from the Brazilian Bar Association exam
-# The exam is required for anyone who wants to practice law in Brazil
-# Dataset: https://huggingface.co/datasets/eduagarcia/oab_exams
-oab_exams_tasks = [
- LightevalTaskConfig(
- name=f"oab_exams_{Language.PORTUGUESE.value}_{formulation.name.lower()}",
- prompt_function=get_mcq_prompt_function(
- Language.PORTUGUESE,
- lambda line: {
- "question": line["question"],
- "choices": line["choices"]["text"],
- "gold_idx": LETTER_INDICES.index(line["answerKey"]),
- },
- formulation=formulation,
- ),
- suite=("lighteval",),
- hf_repo="eduagarcia/oab_exams",
- hf_subset="default",
- evaluation_splits=("train",),
- hf_avail_splits=["train"],
- metric=get_metrics_for_formulation(
- formulation,
- [
- loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
- loglikelihood_acc_metric(normalization=LogProbCharNorm()),
- ],
- ),
- )
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
-]
-
-# ENEM (Exame Nacional do Ensino Médio) is a standardized Brazilian national secondary
-# education examination. The exam is used both as a university admission test and as a
-# high school evaluation test.
-# Dataset: https://huggingface.co/datasets/maritaca-ai/enem
-enem_tasks = [
- LightevalTaskConfig(
- name=f"enem_{Language.PORTUGUESE.value}_{formulation.name.lower()}:{year}",
- prompt_function=get_mcq_prompt_function(
- Language.PORTUGUESE,
- partial(
- enem_adapter,
- Language.PORTUGUESE,
- ),
- formulation=formulation,
- ),
- suite=("lighteval",),
- hf_repo="maritaca-ai/enem",
- hf_subset=year,
- evaluation_splits=("train",),
- hf_avail_splits=["train"],
- metric=get_metrics_for_formulation(
- formulation,
- [
- loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
- loglikelihood_acc_metric(normalization=LogProbCharNorm()),
- ],
- ),
- )
- for year in ["2022", "2023", "2024"]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
@@ -3619,7 +3769,7 @@
# MERA: https://github.com/ai-forever/MERA
worldtree_rus_tasks = [
LightevalTaskConfig(
- name=f"mera_worldtree_{Language.RUSSIAN.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("mera_worldtree", Language.RUSSIAN, formulation),
prompt_function=get_mcq_prompt_function(
Language.RUSSIAN,
lambda line: {
@@ -3634,19 +3784,17 @@
hf_subset="ruworldtree",
evaluation_splits=("train",),
hf_avail_splits=["train"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ Language.RUSSIAN,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(Language.RUSSIAN, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -3654,8 +3802,6 @@
*agieval_tasks_zh,
*worldtree_rus_tasks,
*ceval_tasks,
- *oab_exams_tasks,
- *enem_tasks,
]
)
@@ -3663,20 +3809,22 @@
# ------------------------------- Continuation Tasks ------------------------------- #
xcodah_tasks = [
LightevalTaskConfig(
- name=f"xcodah_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("xcodah", language, formulation),
prompt_function=get_mcq_prompt_function(language, partial(xcodah_adapter, language), formulation=formulation),
suite=("lighteval",),
hf_repo="INK-USC/xcsr",
- hf_subset=f"X-CODAH-{standardize_tag(language.value) if language != Language.JAPANESE else 'jap'}",
+ hf_subset=f"X-CODAH-{standardize_tag(language.value)}",
evaluation_splits=("validation",),
hf_avail_splits=["validation"],
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ language,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.ARABIC,
@@ -3696,16 +3844,12 @@
Language.VIETNAMESE,
Language.CHINESE,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
xstory_tasks = [
LightevalTaskConfig(
- name=f"xstory_cloze_{lang.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("xstory_cloze", lang, formulation),
prompt_function=get_continuation_prompt_function(
lang,
partial(
@@ -3730,13 +3874,15 @@
hf_subset=standardize_tag(lang.value),
evaluation_splits=["eval"],
few_shots_split="train",
- metric=get_metrics_for_formulation(
+ metric=get_metrics_for_mcq_formulation(
formulation,
+ lang,
[
loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
loglikelihood_acc_metric(normalization=LogProbCharNorm()),
],
),
+ stop_sequence=get_stop_sequence(lang, formulation.cot),
)
for lang in [
Language.RUSSIAN,
@@ -3750,11 +3896,7 @@
Language.BASQUE,
Language.BURMESE,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -3768,20 +3910,27 @@
xwinograd_tasks = [
LightevalTaskConfig(
- name=f"xwinograd_{language.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("xwinograd", language, formulation),
suite=("lighteval",),
prompt_function=get_continuation_prompt_function(
- language, partial(winogrand_adapter, language), formulation=formulation
+ language,
+ partial(winogrand_adapter, language),
+ formulation=formulation,
),
hf_repo="Muennighoff/xwinograd",
- hf_subset=standardize_tag(language.value) if language != Language.JAPANESE else "jp",
+ hf_subset=standardize_tag(language.value),
evaluation_splits=("test",),
hf_avail_splits=["test"],
- metric=[
- loglikelihood_acc_metric(normalization=None),
- loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
- loglikelihood_acc_metric(normalization=LogProbCharNorm()),
- ],
+ metric=get_metrics_for_mcq_formulation(
+ formulation,
+ language,
+ [
+ loglikelihood_acc_metric(normalization=None),
+ loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
+ loglikelihood_acc_metric(normalization=LogProbCharNorm()),
+ ],
+ ),
+ stop_sequence=get_stop_sequence(language, formulation.cot),
)
for language in [
Language.ENGLISH,
@@ -3791,35 +3940,34 @@
Language.RUSSIAN,
Language.CHINESE,
]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
winograd_turkish_task = [
LightevalTaskConfig(
- name=f"community_xwinograd_{Language.TURKISH.value}_{formulation.name.lower()}",
+ name=get_mcf_task_name("community_xwinograd", Language.TURKISH, formulation),
suite=("lighteval",),
prompt_function=get_continuation_prompt_function(
- Language.TURKISH, partial(winogrand_adapter, Language.TURKISH), formulation=formulation
+ Language.TURKISH,
+ partial(winogrand_adapter, Language.TURKISH),
+ formulation=formulation,
),
hf_repo="malhajar/winogrande-tr-v0.2",
hf_subset="default",
evaluation_splits=("validation",),
few_shots_split="train",
- metric=[
- loglikelihood_acc_metric(normalization=None),
- loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
- loglikelihood_acc_metric(normalization=LogProbCharNorm()),
- ],
+ metric=get_metrics_for_mcq_formulation(
+ formulation,
+ Language.TURKISH,
+ [
+ loglikelihood_acc_metric(normalization=None),
+ loglikelihood_acc_metric(normalization=LogProbTokenNorm()),
+ loglikelihood_acc_metric(normalization=LogProbCharNorm()),
+ ],
+ ),
+ stop_sequence=get_stop_sequence(Language.TURKISH, formulation.cot),
)
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
- ]
+ for formulation in DEFAULT_FORMULATIONS
]
TASKS_TABLE.extend(
@@ -3844,8 +3992,8 @@
mkqa_tasks = [
LightevalTaskConfig(
- name=f"mkqa_{language.value}:{subset}",
- prompt_function=get_qa_prompt_function(language, partial(get_mkqa_adapter, language)),
+ name=f"{get_generative_task_name('mkqa', language, cot)}:{subset}",
+ prompt_function=get_qa_prompt_function(language, partial(get_mkqa_adapter, language), cot=cot),
suite=("lighteval",),
hf_repo="apple/mkqa",
hf_subset="mkqa",
@@ -3861,15 +4009,17 @@
trust_dataset=True,
evaluation_splits=("train",),
hf_avail_splits=["train"],
- stop_sequence=("\n",),
metric=[
- multilingual_quasi_exact_match_metric(language, "prefix"),
- multilingual_quasi_f1_score_metric(language),
+ multilingual_quasi_exact_match_metric(
+ language, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(language, normalize_pred=get_cot_answer_normalization(cot)),
]
if subset in ["entity", "long_answer", "short_phrase"]
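+ # open-ended answer types are scored with prefix match plus F1; the remaining types use full exact match only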
else [
- multilingual_quasi_exact_match_metric(language, "full"),
+ multilingual_quasi_exact_match_metric(language, "full", normalize_pred=get_cot_answer_normalization(cot)),
],
+ stop_sequence=get_stop_sequence(language, cot),
)
for subset in MKQA_TASK_TO_ID.keys()
for language in [
@@ -3900,6 +4050,7 @@
# Language.CHINESE_HONG_KONG,
# Language.CHINESE_TRADITIONAL,
]
+ for cot in (False, True)
]
mintaka_tasks = [
@@ -3911,18 +4062,19 @@
"question": line["question"],
"choices": [line["answerText"]],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="AmazonScience/mintaka",
hf_subset=standardize_tag(lang.value),
evaluation_splits=("test",),
few_shots_split="train",
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
metric=[
- multilingual_quasi_exact_match_metric(lang, "prefix"),
- multilingual_quasi_f1_score_metric(lang),
+ multilingual_quasi_exact_match_metric(lang, "prefix", normalize_pred=get_cot_answer_normalization(cot)),
+ multilingual_quasi_f1_score_metric(lang, normalize_pred=get_cot_answer_normalization(cot)),
],
+ stop_sequence=get_stop_sequence(lang, cot=cot),
)
for lang in [
Language.ARABIC,
@@ -3935,55 +4087,64 @@
Language.JAPANESE,
Language.PORTUGUESE,
]
+ for cot in (False, True)
]
french_triviqa_tasks = [
LightevalTaskConfig(
- name=f"community_triviaqa_{Language.FRENCH.value}",
+ name=get_generative_task_name("community_triviaqa", Language.FRENCH, cot),
prompt_function=get_qa_prompt_function(
Language.FRENCH,
lambda line: {
"question": line["Question"],
"choices": [line["Answer"]],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="manu/french-trivia",
hf_subset="default",
evaluation_splits=("train",),
hf_avail_splits=["train"],
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generaion_size(cot, 400),
metric=[
- multilingual_quasi_exact_match_metric(Language.FRENCH, "prefix"),
- multilingual_quasi_f1_score_metric(Language.FRENCH),
+ multilingual_quasi_exact_match_metric(
+ Language.FRENCH, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(Language.FRENCH, normalize_pred=get_cot_answer_normalization(cot)),
],
+ stop_sequence=get_stop_sequence(Language.FRENCH, cot),
)
+ for cot in (False, True)
]
chegeka_tasks = [
LightevalTaskConfig(
- name=f"chegeka_{Language.RUSSIAN.value}",
+ name=get_generative_task_name("chegeka", Language.RUSSIAN, cot),
prompt_function=get_qa_prompt_function(
Language.RUSSIAN,
lambda line: {
"question": line["inputs"]["text"],
"choices": [line["outputs"]],
},
+ cot=cot,
),
suite=("lighteval",),
hf_repo="ai-forever/MERA",
hf_subset="chegeka",
evaluation_splits=("train",),
hf_avail_splits=["train"],
- generation_size=400,
- stop_sequence=("\n",),
+ generation_size=get_cot_generation_size(cot, 400),
metric=[
- multilingual_quasi_exact_match_metric(Language.RUSSIAN, "prefix"),
- multilingual_quasi_f1_score_metric(Language.RUSSIAN),
+ multilingual_quasi_exact_match_metric(
+ Language.RUSSIAN, "prefix", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ multilingual_quasi_f1_score_metric(Language.RUSSIAN, normalize_pred=get_cot_answer_normalization(cot)),
],
+ stop_sequence=get_stop_sequence(Language.RUSSIAN, cot),
)
+ for cot in (False, True)
]
TASKS_TABLE.extend(
@@ -4061,7 +4222,7 @@
acva_tasks = [
LightevalTaskConfig(
- name=f"acva_{Language.ARABIC.value}:{subset}",
+ name=get_generative_task_name("acva", Language.ARABIC, cot),
prompt_function=get_boolq_prompt_function(
Language.ARABIC,
lambda line: {
@@ -4075,11 +4236,17 @@
hf_subset=subset,
evaluation_splits=("test",),
few_shots_split="validation",
- metric=[multilingual_quasi_exact_match_metric(Language.ARABIC, "full"), loglikelihood_acc_metric()],
- generation_size=5,
+ metric=[
+ multilingual_quasi_exact_match_metric(
+ Language.ARABIC, "full", normalize_pred=get_cot_answer_normalization(cot)
+ ),
+ loglikelihood_acc_metric(),
+ ],
+ generation_size=get_cot_generation_size(cot, 5),
stop_sequence=("\n",),
)
for subset in ACVA_SUBSET
+ for cot in (False, True)
]
@@ -4368,7 +4535,7 @@ def flores_adapter(lang1, lang2):
source_language=Language(manage_duplicate_language_codes(lang1.split("_")[0])),
target_language=Language(manage_duplicate_language_codes(lang2.split("_")[0])),
adapter=flores_adapter(lang1, lang2),
- formulation=CFFormulation(),
+ formulation=CFFormulation(cot=cot),
),
suite=("lighteval",),
hf_repo="facebook/flores",
@@ -4377,13 +4544,19 @@ def flores_adapter(lang1, lang2):
evaluation_splits=["devtest"],
few_shots_split="dev",
few_shots_select=None,
- generation_size=300,
- metric=[Metrics.chrf_plus, Metrics.bleu, Metrics.bleu_1, Metrics.bleu_4],
- stop_sequence=["\n"],
+ generation_size=get_cot_generation_size(cot, 300),
+ metric=[
+ translation_metric(metric_name="chrf++", normalize_pred=get_cot_answer_normalization(cot)),
+ translation_metric(metric_name="bleu", normalize_pred=get_cot_answer_normalization(cot)),
+ translation_metric(metric_name="bleu_1", normalize_pred=get_cot_answer_normalization(cot)),
+ translation_metric(metric_name="bleu_4", normalize_pred=get_cot_answer_normalization(cot)),
+ ],
+ stop_sequence=get_stop_sequence(Language.ENGLISH, cot),
trust_dataset=True,
version=0,
)
for (lang1, lang2) in combinations(flores_200_languages, 2)
+ for cot in (False, True)
]
TASKS_TABLE.extend(
diff --git a/src/lighteval/tasks/multilingual/utils/adapters_utils.py b/src/lighteval/tasks/multilingual/utils/adapters_utils.py
index 2e7f27dda..e30bbf512 100644
--- a/src/lighteval/tasks/multilingual/utils/adapters_utils.py
+++ b/src/lighteval/tasks/multilingual/utils/adapters_utils.py
@@ -133,3 +133,8 @@ def extract_answer(acc: tuple[str, int, list[str]], symbol: str) -> tuple[str, i
answer[: len(prefix)]: answer[len(prefix) :].strip() for answer, prefix in zip(found_answers, answer_prefixes)
}
return last_index, prefix_answer_dict
+
+
+def float_to_choice_string(answer: float) -> str:
+ answer = float(answer)
+ return str(int(answer)) if answer.is_integer() else str(answer)
diff --git a/src/lighteval/tasks/multilingual/utils/task_utils.py b/src/lighteval/tasks/multilingual/utils/task_utils.py
index d8e73dac8..29053ec2f 100644
--- a/src/lighteval/tasks/multilingual/utils/task_utils.py
+++ b/src/lighteval/tasks/multilingual/utils/task_utils.py
@@ -21,22 +21,74 @@
# SOFTWARE.
-from lighteval.metrics.dynamic_metrics import loglikelihood_acc_metric
+import re
+from typing import Callable
+
+from lighteval.metrics.dynamic_metrics import (
+ IndicesExtractionConfig,
+ loglikelihood_acc_metric,
+ multilingual_extractive_match_metric,
+)
from lighteval.metrics.utils.metric_utils import Metric
from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation
+from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS
+from lighteval.utils.language import Language
def normalize_subset(subset: str) -> str:
return subset.replace(" ", "_").replace("(", "").replace(")", "").lower()
-def get_metrics_for_formulation(formulation: Formulation, metrics: list[Metric]) -> list[Metric]:
+def get_metrics_for_mcq_formulation(
+ formulation: Formulation, language: Language, metrics: list[Metric]
+) -> list[Metric]:
"""
Choose the appropriate metrics for the given formulation otherwise fallback to the original metrics.
"""
match formulation:
- #
- case MCFFormulation(choice_prefix="Letters"):
+ case MCFFormulation(choice_prefix="Letters" | "Numbers", cot=False):
return [loglikelihood_acc_metric(normalization=None)]
+ # In case of NativeLetters we can't use just acc_metric, because the letters can be made of multiple tokens
+ case MCFFormulation(cot=False):
+ return metrics
+ case MCFFormulation(cot=True):
+ return [
+ multilingual_extractive_match_metric(
+ language,
+ gold_extraction_target=(IndicesExtractionConfig(prefix_for_extraction=formulation.choice_prefix),),
+ ),
+ ]
case _:
return metrics
+
+
+def get_cot_generation_size(cot: bool, generation_size: int) -> int | None:
+ return None if cot else generation_size
+
+
+def get_stop_sequence(language: Language, cot: bool) -> list[str] | None:
+ stop_sequence = ["\n"] if cot else []
+ try:
+ trans = TRANSLATION_LITERALS[language]
+ # Ensure the question word is on a new line, as otherwise LLMs sometimes like to generate:
+ # "**Répondez à la" or "1. **Comprendre la" in their CoT generations
+ return [
+ f"\n{trans.question_word}{trans.colon}",
+ f"\n{trans.question_word.capitalize()}{trans.colon}",
+ ] + stop_sequence
+ except (AttributeError, KeyError):
+ return stop_sequence
+
+
+def get_cot_answer_normalization(cot: bool = False) -> Callable[[str], str] | None:
+ if not cot:
+ return None
+
+ compiled_b_regex = re.compile(r"<b>(.*?)</b>")
+
+ def bb_normalizer(text: str) -> str:
+ matches = compiled_b_regex.findall(text)
+ last = matches[-1] if matches else text
+ return last
+
+ return bb_normalizer
diff --git a/src/lighteval/tasks/templates/boolq.py b/src/lighteval/tasks/templates/boolq.py
index 09959e874..e0b19575b 100644
--- a/src/lighteval/tasks/templates/boolq.py
+++ b/src/lighteval/tasks/templates/boolq.py
@@ -46,6 +46,7 @@ class BoolQInput(TypedDict):
answer: bool
instruction: NotRequired[str]
context: NotRequired[str]
+ few_shot_cot: NotRequired[str]
class BoolQAdapter(TypedDict):
@@ -62,6 +63,7 @@ class BoolQAdapter(TypedDict):
answer: str
instruction: NotRequired[str]
context: NotRequired[str]
+ few_shot_cot: NotRequired[str]
def get_boolq_prompt_function(
@@ -90,6 +92,7 @@ def get_boolq_prompt_function(
"context": "context",
"instruction": "instruction",
"gold_idx": "gold_idx",
+ "few_shot_cot": "few_shot_cot",
},
formulation,
)
@@ -105,11 +108,13 @@ def boolq_prompt(
choices = [translation_literals.yes, translation_literals.no]
return mcq_prompt_fn(
{
+ **{x: line[x] for x in line if x.startswith("__")},
"instruction": input_data.get("instruction", ""),
"question": input_data["question"],
"context": input_data.get("context", ""),
"choices": choices,
"gold_idx": 0 if input_data["answer"] else 1,
+ "few_shot_cot": input_data.get("few_shot_cot", ""),
},
task_name,
)
diff --git a/src/lighteval/tasks/templates/continuation.py b/src/lighteval/tasks/templates/continuation.py
index c9cd5d1bc..9266986c3 100644
--- a/src/lighteval/tasks/templates/continuation.py
+++ b/src/lighteval/tasks/templates/continuation.py
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
+import logging
from typing import Callable
from typing_extensions import NotRequired, TypedDict
@@ -44,9 +45,11 @@
from lighteval.utils.utils import as_list
+logger = logging.getLogger(__name__)
+
CONTINUATION_QUERY_CF = "{instruction}{context}"
-CONTINUATION_QUERY_MCF = "{instruction}{context}\n{options}{answer_word}{colon}"
+CONTINUATION_QUERY_MCF = "{instruction}{context}\n\n{options_word}{colon}\n{options}{answer_word}{colon}"
# Defined for type hinting only
@@ -64,6 +67,7 @@ class ContinuationInput(TypedDict):
continuations: list[str]
gold_idx: list[int] | int
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
class ContinuationDictAdapter(TypedDict):
@@ -80,6 +84,7 @@ class ContinuationDictAdapter(TypedDict):
continuations: str
gold_idx: str
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
def get_continuation_prompt_function(
@@ -107,6 +112,8 @@ def get_continuation_prompt_function(
*MCF*
Context
+
+ Options:
A. Continuation 1
B. Continuation 2
C. Continuation 3
@@ -126,13 +133,31 @@ def get_continuation_prompt_function(
adapter_fn = create_adapter_from_dict(adapter)
translation_literals = TRANSLATION_LITERALS[language]
+ WARNED_ABOUT_COT_INSTRUCTION = False
+
def prepare_prompt(line: dict):
cont_input = adapter_fn(line)
if cont_input is None:
return None
instruction_val = cont_input.get("instruction")
- instruction = f"{instruction_val}\n" if instruction_val else ""
+ if formulation.cot and not instruction_val:
+ if not isinstance(formulation, MCFFormulation) or formulation.choice_prefix not in [
+ "Letters",
+ "NativeLetters",
+ ]:
+ raise ValueError(
+ "You are using a COT with a unsupported formulation. Either use MCF formulation or provide an instruction."
+ )
+
+ instruction_val = f"{translation_literals.continuation_mcf_instruction}\n{translation_literals.default_formatting_instruction}"
+ nonlocal WARNED_ABOUT_COT_INSTRUCTION
+ if not WARNED_ABOUT_COT_INSTRUCTION:
+ logger.warning(
+ f" You are using a COT with MCF formulation but did not provide an instruction. Defaulting to {instruction_val}"
+ )
+ WARNED_ABOUT_COT_INSTRUCTION = True
+ instruction = f"{instruction_val}\n\n" if instruction_val else ""
context = (
f"{capitalize(fix_ending_punct(cont_input['context'], translation_literals))}"
@@ -182,16 +207,31 @@ def prompt_fn_mcf(line, task_name: str):
options = build_choices(continuations, formulation, translation_literals)
options = f"{options}\n" if options else ""
- answers = build_answers(continuations, formulation, translation_literals)
-
- answer_word = capitalize(translation_literals.answer)
+ answer_word = capitalize(
+ translation_literals.answer if not formulation.cot else translation_literals.answer_cot
+ )
+ options_word = capitalize(translation_literals.options_word)
query = CONTINUATION_QUERY_MCF.format(
instruction=instruction,
context=context,
options=options,
answer_word=answer_word,
colon=translation_literals.colon,
+ options_word=options_word,
+ )
+
+ few_shot_cot = cont_input.get("few_shot_cot", None)
+ is_few_shot = line.get("__few_shots", False)
+ if formulation.cot and few_shot_cot and is_few_shot:
+ continuations = [
+ capitalize(fix_ending_punct(answer, translation_literals)) for answer in as_list(few_shot_cot)
+ ]
+ answers = build_answers(
+ continuations,
+ formulation,
+ translation_literals,
+ is_few_shot=is_few_shot and bool(few_shot_cot),
)
return Doc(
diff --git a/src/lighteval/tasks/templates/copa.py b/src/lighteval/tasks/templates/copa.py
index a4d82c4de..46f25201d 100644
--- a/src/lighteval/tasks/templates/copa.py
+++ b/src/lighteval/tasks/templates/copa.py
@@ -53,6 +53,7 @@ class COPAInput(TypedDict):
continuations: list[str]
gold_idx: int | list[int]
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
class COPAAdapter(TypedDict):
@@ -71,6 +72,7 @@ class COPAAdapter(TypedDict):
continuations: str
gold_idx: str
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
def get_copa_prompt_function(
@@ -113,7 +115,15 @@ def get_copa_prompt_function(
"""
adapter_fn = create_adapter_from_dict(adapter)
continuation_prompt_fn = get_continuation_prompt_function(
- language, {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, formulation
+ language,
+ {
+ "context": "context",
+ "continuations": "continuations",
+ "gold_idx": "gold_idx",
+ "few_shot_cot": "few_shot_cot",
+ "instruction": "instruction",
+ },
+ formulation,
)
translation_literals = TRANSLATION_LITERALS[language]
@@ -140,10 +150,12 @@ def copa_prompt(
return continuation_prompt_fn(
{
+ **{x: line[x] for x in line if x.startswith("__")},
"instruction": input_data.get("instruction", ""),
"context": context,
"continuations": input_data["continuations"],
"gold_idx": input_data["gold_idx"],
+ "few_shot_cot": input_data.get("few_shot_cot", ""),
},
task_name,
)
diff --git a/src/lighteval/tasks/templates/hellaswag.py b/src/lighteval/tasks/templates/hellaswag.py
index 79108a9cd..2f6f0b07d 100644
--- a/src/lighteval/tasks/templates/hellaswag.py
+++ b/src/lighteval/tasks/templates/hellaswag.py
@@ -49,6 +49,7 @@ class HellaswagInput(TypedDict):
instruction: NotRequired[str]
activity_label: NotRequired[str]
ctx_b: NotRequired[str]
+ few_shot_cot: NotRequired[str]
class HellaswagAdapter(TypedDict):
@@ -58,6 +59,7 @@ class HellaswagAdapter(TypedDict):
instruction: NotRequired[str]
activity_label: NotRequired[str]
ctx_b: NotRequired[str]
+ few_shot_cot: NotRequired[str]
def get_hellaswag_prompt_function(
@@ -107,7 +109,15 @@ def join_ctxs(ctx_a, ctx_b):
adapter_fn = create_adapter_from_dict(adapter)
continuation_prompt_fn = get_continuation_prompt_function(
- language, {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"}, formulation
+ language,
+ {
+ "context": "context",
+ "continuations": "continuations",
+ "gold_idx": "gold_idx",
+ "instruction": "instruction",
+ "few_shot_cot": "few_shot_cot",
+ },
+ formulation,
)
def hellaswag_prompt(
@@ -145,10 +155,12 @@ def hellaswag_prompt(
return continuation_prompt_fn(
{
+ **{x: line[x] for x in line if x.startswith("__")},
"instruction": input_data.get("instruction", ""),
"context": full_context,
"continuations": choices,
"gold_idx": input_data["gold_idx"],
+ "few_shot_cot": input_data.get("few_shot_cot", ""),
},
task_name,
)
diff --git a/src/lighteval/tasks/templates/math_qa.py b/src/lighteval/tasks/templates/math_qa.py
new file mode 100644
index 000000000..5bc591abf
--- /dev/null
+++ b/src/lighteval/tasks/templates/math_qa.py
@@ -0,0 +1,87 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import logging
+from typing import Callable
+
+from lighteval.tasks.templates.multichoice import MCQInput, create_adapter_from_dict, get_mcq_prompt_function
+from lighteval.tasks.templates.qa import QAAdapter, QAInput
+from lighteval.tasks.templates.utils.formulation import CFFormulation
+from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS
+from lighteval.utils.language import Language
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_math_qa_prompt_function(
+ language: Language, adapter: Callable[[dict], QAInput | None] | QAAdapter, cot: bool = False
+):
+ """
+ Create a templated prompt function for a math QA task.
+ Example tasks:
+ - MathQA
+ - GSM8K
+
+ Format:
+ Question: xxx
+ Answer: | Answer
+
+ Args:
+ language (Language): The language of the QA task.
+ adapter (Callable[[dict], QAInput] | QAAdapter): A function or dictionary to adapt the input data to the required QAInput format.
+ Must map data from the dataset row to the QAInput format.
+
+ Returns:
+ Callable: A function that generates QA prompts based on the given parameters.
+ """
+
+ adapter_fn = create_adapter_from_dict(adapter)
+ WARNED_ABOUT_INSTRUCTION = False
+
+ def adapter_for_mcq(line: dict) -> MCQInput | None:
+ input_data = adapter_fn(line)
+ if input_data is None:
+ return None
+
+ choices = input_data["choices"]
+ instruction = input_data.get("instruction", "")
+ if cot and not instruction:
+ translation_literals = TRANSLATION_LITERALS[language]
+ instruction = f"{translation_literals.qa_instruction}\n{translation_literals.math_formatting_instruction}"
+ nonlocal WARNED_ABOUT_INSTRUCTION
+ if not WARNED_ABOUT_INSTRUCTION:
+ logger.warning(
+ f"You are using Math-QA with cot, but did not provide instruction. Default to {instruction}."
+ )
+ WARNED_ABOUT_INSTRUCTION = True
+
+ return {
+ **input_data,
+ "gold_idx": list(range(len(choices))),
+ "instruction": instruction,
+ }
+
+ multichoice_prompt_fn = get_mcq_prompt_function(
+ language, adapter=adapter_for_mcq, formulation=CFFormulation(cot=cot)
+ )
+ return multichoice_prompt_fn
diff --git a/src/lighteval/tasks/templates/multichoice.py b/src/lighteval/tasks/templates/multichoice.py
index 92808488e..2b89205f1 100644
--- a/src/lighteval/tasks/templates/multichoice.py
+++ b/src/lighteval/tasks/templates/multichoice.py
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
+import logging
from typing import Callable
from typing_extensions import NotRequired, TypedDict
@@ -27,12 +28,20 @@
from lighteval.tasks.requests import Doc
from lighteval.tasks.templates.utils.adapter_utils import create_adapter_from_dict
from lighteval.tasks.templates.utils.formatting_utils import capitalize, fix_ending_punct
-from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation, build_answers, build_choices
+from lighteval.tasks.templates.utils.formulation import (
+ CFFormulation,
+ Formulation,
+ MCFFormulation,
+ build_answers,
+ build_choices,
+)
from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS
from lighteval.utils.language import Language
from lighteval.utils.utils import as_list
+logger = logging.getLogger(__name__)
+
MULTI_CHOICE_QA_QUERY = (
"{instruction}{context}{question_word}{colon}{sentence_space}{question}\n{options}{answer_word}{colon}"
)
@@ -55,6 +64,7 @@ class MCQInput(TypedDict):
gold_idx: list[int] | int
context: NotRequired[str]
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
class MCQDictAdapter(TypedDict):
@@ -73,6 +83,7 @@ class MCQDictAdapter(TypedDict):
gold_idx: str
context: NotRequired[str]
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
# Python too dumb to do fancy inference :(
@@ -122,6 +133,8 @@ def get_mcq_prompt_function(
adapter_fn = create_adapter_from_dict(adapter)
+ WARNED_ABOUT_COT_INSTRUCTION = False
+
def prompt_fn(line, task_name: str):
mcq_input = adapter_fn(line)
if mcq_input is None:
@@ -130,19 +143,42 @@ def prompt_fn(line, task_name: str):
translation_literals = TRANSLATION_LITERALS[language]
instruction_val = mcq_input.get("instruction")
- instruction = f"{instruction_val}\n" if instruction_val else ""
+ if formulation.cot and not instruction_val:
+ match formulation:
+ case MCFFormulation(choice_prefix="Letters") | MCFFormulation(choice_prefix="NativeLetters"):
+ instruction_val = f"{translation_literals.multichoice_mcf_instruction}\n{translation_literals.default_formatting_instruction}"
+ case CFFormulation():
+ instruction_val = (
+ f"{translation_literals.qa_instruction}\n{translation_literals.default_formatting_instruction}"
+ )
+ case _:
+ raise ValueError(
+ "You are using a COT with a unsupported formulation. Either use CF/MCF formulation or provide an instruction."
+ )
+
+ nonlocal WARNED_ABOUT_COT_INSTRUCTION
+ if not WARNED_ABOUT_COT_INSTRUCTION:
+ logger.warning(
+ f"You are using a COT with CF/MCF formulation but did not provide an instruction. Defaulting to {instruction_val}"
+ )
+ WARNED_ABOUT_COT_INSTRUCTION = True
+
+ instruction = f"{instruction_val}\n\n" if instruction_val else ""
context_val = mcq_input.get("context")
context = f"{capitalize(fix_ending_punct(context_val, translation_literals))}\n" if context_val else ""
question = capitalize(fix_ending_punct(mcq_input["question"], translation_literals))
- answers = [capitalize(fix_ending_punct(str(answer), translation_literals)) for answer in mcq_input["choices"]]
+ answers = mcq_input["choices"]
+ gold_idx = mcq_input["gold_idx"]
+ answers = [capitalize(fix_ending_punct(answer, translation_literals)) for answer in answers]
options = build_choices(answers, formulation, translation_literals)
options = f"{options}\n" if options else ""
- answers = build_answers(answers, formulation, translation_literals)
- answer_word = capitalize(translation_literals.answer)
+ answer_word = capitalize(
+ translation_literals.answer if not formulation.cot else translation_literals.answer_cot
+ )
question_word = capitalize(translation_literals.question_word)
query = MULTI_CHOICE_QA_QUERY.format(
@@ -156,10 +192,25 @@ def prompt_fn(line, task_name: str):
options=options,
)
+ # If we are in few-shot mode, then we want to use cot-answers instead of the actual answers
+ # NOTE: it's important to do it after query formatting, because otherwise the options will contain cot
+ is_few_shot = line.get("__few_shots", False)
+ few_shot_cot = mcq_input.get("few_shot_cot", None)
+ if few_shot_cot and formulation.cot and is_few_shot:
+ answers = [capitalize(fix_ending_punct(answer, translation_literals)) for answer in as_list(few_shot_cot)]
+ gold_idx = list(range(len(answers)))
+ gold_idx = as_list(gold_idx)
+
+ answers = build_answers(
+ answers, formulation, translation_literals, is_few_shot=is_few_shot and bool(few_shot_cot)
+ )
+
return Doc(
task_name=task_name,
query=query,
- gold_index=as_list(mcq_input["gold_idx"]),
+ gold_index=gold_idx,
choices=answers,
instruction=instruction_val,
unconditioned_query=f"{answer_word}{translation_literals.colon}",
diff --git a/src/lighteval/tasks/templates/nli.py b/src/lighteval/tasks/templates/nli.py
index e8809e17b..730f4978a 100644
--- a/src/lighteval/tasks/templates/nli.py
+++ b/src/lighteval/tasks/templates/nli.py
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
+import logging
from typing import Callable, Literal
from typing_extensions import NotRequired, TypedDict
@@ -33,6 +34,9 @@
from lighteval.utils.language import Language
+logger = logging.getLogger(__name__)
+
+
NLI_TEMPLATE_QUERY_CF = "{instruction}{premise}{word_space}{confirmation_word}{question_mark}"
NLI_TEMPLATE_CONT_CF = "{sentence_space}{label}{comma}{word_space}{hypothesis}"
@@ -51,6 +55,7 @@ class NLIInput(TypedDict):
hypothesis: str
gold_idx: int
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
class NLIAdapter(TypedDict):
@@ -67,6 +72,7 @@ class NLIAdapter(TypedDict):
hypothesis: str
gold_idx: str
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
RelationType = Literal["entailment", "neutral", "contradiction"]
@@ -201,6 +207,7 @@ def get_nli_prompt_function(
Callable: A function that generates NLI prompts based on the given parameters.
"""
# We use natural implementation for CF formulation to comply with standard evaluation formats
+ WARNED_ABOUT_COT_INSTRUCTION = False
if isinstance(formulation, CFFormulation):
return _nli_prompt_function_natural(language, adapter, relations)
@@ -219,8 +226,15 @@ def get_relation_label(label: RelationType, translation_literals: TranslationLit
# For hybrid we use inlined choices so we use the cf formulation in multichoice prompt fn
mcq_prompt_fn = get_mcq_prompt_function(
language,
- {"context": "premise", "question": "hypothesis", "choices": "choices", "gold_idx": "gold_idx"},
- CFFormulation() if isinstance(formulation, HybridFormulation) else formulation,
+ {
+ "context": "premise",
+ "question": "hypothesis",
+ "choices": "choices",
+ "gold_idx": "gold_idx",
+ "few_shot_cot": "few_shot_cot",
+ "instruction": "instruction",
+ },
+ CFFormulation(cot=formulation.cot) if isinstance(formulation, HybridFormulation) else formulation,
)
def prompt_fn(line: dict, task_name: str):
@@ -234,6 +248,7 @@ def prompt_fn(line: dict, task_name: str):
premise, hypothesis, gold_idx = input_data["premise"], input_data["hypothesis"], input_data["gold_idx"]
premise = fix_ending_punct(capitalize(input_data["premise"]), translation_literals)
hypothesis = input_data["hypothesis"]
+ instruction = input_data.get("instruction", "")
if isinstance(formulation, HybridFormulation):
# If we have the neither option move it to the end to be consistent with standard NLI evaluation
rearranged_labels = labels
@@ -243,15 +258,32 @@ def prompt_fn(line: dict, task_name: str):
choices_str = f"{translation_literals.comma}{translation_literals.word_space}".join(rearranged_labels[:-1])
hypothesis = f"{hypothesis.rstrip(PUNCT)}{translation_literals.sentence_space}{choices_str}{translation_literals.word_space}{translation_literals.or_word}{translation_literals.word_space}{rearranged_labels[-1]}{translation_literals.question_mark}"
+ elif isinstance(formulation, MCFFormulation):
+ if formulation.cot and not instruction:
+ match formulation:
+ case MCFFormulation(choice_prefix="Letters") | MCFFormulation(choice_prefix="NativeLetters"):
+ instruction = f"{translation_literals.nli_mcf_instruction}\n{translation_literals.default_formatting_instruction}"
+ nonlocal WARNED_ABOUT_COT_INSTRUCTION
+ if not WARNED_ABOUT_COT_INSTRUCTION:
+ logger.warning(
+ f"You are using a COT with MCF formulation but did not provide an instruction. Defaulting to {instruction}"
+ )
+ WARNED_ABOUT_COT_INSTRUCTION = True
+ case _:
+ raise ValueError(
+ "You are using a COT with a unsupported formulation. Either use MCF formulation or provide an instruction."
+ )
# (hynky1999): Ideally we would not compute logprobs of the Yes/No/Also in CF formulation. However as of right now lighteval doesn't allow to
# use multi-context.
row = {
- "instruction": input_data.get("instruction", ""),
+ **{x: line[x] for x in line if x.startswith("__")},
+ "instruction": instruction,
"premise": premise,
"hypothesis": hypothesis,
"gold_idx": gold_idx,
"choices": labels,
+ "few_shot_cot": input_data.get("few_shot_cot", ""),
}
return mcq_prompt_fn(row, task_name)
diff --git a/src/lighteval/tasks/templates/qa.py b/src/lighteval/tasks/templates/qa.py
index e798f820d..0a1509e4b 100644
--- a/src/lighteval/tasks/templates/qa.py
+++ b/src/lighteval/tasks/templates/qa.py
@@ -34,6 +34,7 @@ class QAInput(TypedDict):
choices: list[str]
context: NotRequired[str]
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
class QAAdapter(TypedDict):
@@ -41,9 +42,12 @@ class QAAdapter(TypedDict):
context: str
context: NotRequired[str]
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
-def get_qa_prompt_function(language: Language, adapter: Callable[[dict], QAInput | None] | QAAdapter):
+def get_qa_prompt_function(
+ language: Language, adapter: Callable[[dict], QAInput | None] | QAAdapter, cot: bool = False
+):
"""
Create a templated prompt function for a QA task.
Example tasks:
@@ -70,13 +74,14 @@ def adapter_for_mcq(line: dict) -> MCQInput | None:
if input_data is None:
return None
- choices = list(set(input_data["choices"]))
+ choices = input_data["choices"]
return {
**input_data,
"gold_idx": list(range(len(choices))),
- "choices": choices,
}
- multichoice_prompt_fn = get_mcq_prompt_function(language, adapter=adapter_for_mcq, formulation=CFFormulation())
+ multichoice_prompt_fn = get_mcq_prompt_function(
+ language, adapter=adapter_for_mcq, formulation=CFFormulation(cot=cot)
+ )
return multichoice_prompt_fn
diff --git a/src/lighteval/tasks/templates/translation.py b/src/lighteval/tasks/templates/translation.py
index c90b99e01..ee9e42058 100644
--- a/src/lighteval/tasks/templates/translation.py
+++ b/src/lighteval/tasks/templates/translation.py
@@ -20,24 +20,29 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
+import logging
from typing import Callable
+from langcodes import Language as LangCodeLanguage
from langcodes import standardize_tag
from typing_extensions import NotRequired, TypedDict
from lighteval.tasks.templates.continuation import get_continuation_prompt_function
from lighteval.tasks.templates.multichoice import create_adapter_from_dict
from lighteval.tasks.templates.utils.formatting_utils import capitalize, fix_ending_punct
-from lighteval.tasks.templates.utils.formulation import Formulation, MCFFormulation
+from lighteval.tasks.templates.utils.formulation import CFFormulation, Formulation, MCFFormulation
from lighteval.tasks.templates.utils.translation_literals import TRANSLATION_LITERALS
from lighteval.utils.language import Language
from lighteval.utils.utils import as_list
+logger = logging.getLogger(__name__)
+
# Template chosen so that it's not very language-dependent, as it's not clear whether one should use the target or source language.
# It's also the best template based on https://arxiv.org/pdf/2301.07069.
+TRANSLATION_INSTRUCTION = "Translate the following text from {source_language} to {target_language}."
TRANSLATION_CONTEXT = "{source_label}{colon}{sentence_space}{source_text}{sentence_space}{target_label}{colon}"
@@ -55,6 +60,7 @@ class TranslationInput(TypedDict):
target_text: str | list[str]
gold_idx: NotRequired[int | list[int]]
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
class TranslationAdapter(TypedDict):
@@ -68,8 +74,9 @@ class TranslationAdapter(TypedDict):
source_text: str
target_text: str
- gold_idx: NotRequired[int | list[int]]
+ gold_idx: NotRequired[str]
instruction: NotRequired[str]
+ few_shot_cot: NotRequired[str]
def get_translation_prompt_function(
@@ -110,7 +117,13 @@ def get_translation_prompt_function(
adapter_fn = create_adapter_from_dict(adapter)
continuation_prompt_fn = get_continuation_prompt_function(
Language.ENGLISH,
- {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"},
+ {
+ "context": "context",
+ "continuations": "continuations",
+ "gold_idx": "gold_idx",
+ "instruction": "instruction",
+ "few_shot_cot": "few_shot_cot",
+ },
formulation,
fix_formatting=False,
)
@@ -119,6 +132,10 @@ def get_translation_prompt_function(
source_label_string = standardize_tag(source_language.value).upper()
target_label_string = standardize_tag(target_language.value).upper()
+ source_language_display_name = LangCodeLanguage.get(source_language.value).display_name()
+ target_language_display_name = LangCodeLanguage.get(target_language.value).display_name()
+
+ WARNED_ABOUT_COT_INSTRUCTION = False
def translation_prompt(
line: dict,
@@ -143,12 +160,41 @@ def translation_prompt(
for text in as_list(input_data["target_text"])
]
+ # Handle instruction
+ instruction_val = input_data.get("instruction")
+ if formulation.cot and not instruction_val:
+ match formulation:
+ case CFFormulation():
+ translation_instruction = TRANSLATION_INSTRUCTION.format(
+ source_language=source_language_display_name, target_language=target_language_display_name
+ )
+ instruction_val = (
+ f"{translation_instruction}\n{source_translation_literals.default_formatting_instruction}"
+ )
+ case MCFFormulation(choice_prefix="NativeLetters") | MCFFormulation(choice_prefix="Letters"):
+ instruction_val = f"{source_translation_literals.multichoice_mcf_instruction}\n{source_translation_literals.default_formatting_instruction}"
+ case _:
+ raise ValueError(
+ "You are using a COT with a unsupported formulation. Either use CF/MCF formulation or provide an instruction."
+ )
+
+ nonlocal WARNED_ABOUT_COT_INSTRUCTION
+ if not WARNED_ABOUT_COT_INSTRUCTION:
+ logger.warning(
+ f" You are using a COT with MCF formulation but did not provide an instruction. Defaulting to {instruction_val}"
+ )
+ WARNED_ABOUT_COT_INSTRUCTION = True
+
+ instruction = f"{instruction_val}" if instruction_val else ""
+
return continuation_prompt_fn(
{
- "instruction": input_data.get("instruction", ""),
+ **{x: line[x] for x in line if x.startswith("__")},
+ "instruction": instruction,
"context": context,
"continuations": continuations,
"gold_idx": input_data.get("gold_idx", list(range(len(continuations)))),
+ "few_shot_cot": input_data.get("few_shot_cot", ""),
},
task_name,
)
diff --git a/src/lighteval/tasks/templates/utils/formulation.py b/src/lighteval/tasks/templates/utils/formulation.py
index 6447e01dc..4cffb54ef 100644
--- a/src/lighteval/tasks/templates/utils/formulation.py
+++ b/src/lighteval/tasks/templates/utils/formulation.py
@@ -39,10 +39,12 @@ class MCFFormulation:
Args:
choice_prefix (ChoicePrefix, optional): The choice prefix to use for the task. Defaults to "Letters".
+ cot (bool, optional): Whether to use COT for the task. Defaults to False.
"""
choice_prefix: ChoicePrefix = "Letters"
name: str = "MCF"
+ cot: bool = False
@dataclass
@@ -58,6 +60,7 @@ class HybridFormulation:
choice_prefix: ChoicePrefix = "Letters"
name: str = "Hybrid"
+ cot: bool = False
@dataclass
@@ -68,6 +71,7 @@ class CFFormulation:
"""
name: str = "CF"
+ cot: bool = False
Formulation = CFFormulation | HybridFormulation | MCFFormulation
@@ -131,6 +135,7 @@ def build_answers(
formulation: Formulation,
translation_literals: TranslationLiterals,
use_sentence_space: bool = True,
+ is_few_shot: bool = False,
) -> list[str]:
"""
Builds a string version of the answers based on passed formulation.
@@ -144,7 +149,7 @@ def build_answers(
use_sentence_space (bool, optional): Whether to use sentence or word space in front of the answer. Defaults to True.
The same value should be passed to `build_choices` function to ensure consistent tokenization.
"""
- if isinstance(formulation, MCFFormulation):
+ if isinstance(formulation, MCFFormulation) and not (formulation.cot and is_few_shot):
prefixes = get_prefix(formulation.choice_prefix, translation_literals)
answers = [prefixes[i] for i in range(len(answers))]
diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py
index 740475cb7..fcbe09538 100644
--- a/src/lighteval/tasks/templates/utils/translation_literals.py
+++ b/src/lighteval/tasks/templates/utils/translation_literals.py
@@ -35,6 +35,8 @@ class TranslationLiterals:
question_word: str = None # type: ignore
answer: str = None # type: ignore
+ answer_cot: str = None # type: ignore
+ options_word: str = "options" # type: ignore
confirmation_word: str = None # type: ignore
yes: str = None # type: ignore
no: str = None # type: ignore
@@ -50,7 +52,7 @@ class TranslationLiterals:
neither: str = None # type: ignore
# Punctuation
- full_stop: str = "."
+ full_stop: str = "." # Separating sequence
comma: str = ","
question_mark: str = "?"
exclamation_mark: str = "!"
@@ -62,6 +64,17 @@ class TranslationLiterals:
# Indices
indices: list[str] = field(default_factory=lambda: LETTER_INDICES)
+ # Instructions
+ continuation_mcf_instruction: str = "Choose the letter of the most likely continuation."
+ nli_mcf_instruction: str = "Choose the letter of the most likely relation between the premise and hypothesis."
+ multichoice_mcf_instruction: str = "Choose the letter of the correct answer."
+ qa_instruction: str = "Answer the following question."
+ # Formatting instruction
+ # This format seems to work quite well across models. We don't put the answer in tags because sometimes models will repeat
+ # the answer resulting in <b>answer</b> ACTUAL ANSWER
+ default_formatting_instruction: str = "Output the final answer in format: <b></b>."
+ math_formatting_instruction: str = "Output the answer in \\boxed{}."
+
def __getattribute__(self, name: str) -> str:
value = super().__getattribute__(name)
if value is None:
@@ -84,6 +97,8 @@ def __getattribute__(self, name: str) -> str:
language=Language.ARABIC,
question_word="سؤال",
answer="إجابة",
+ answer_cot="الإجابة خطوة بخطوة",
+ options_word="خيارات",
confirmation_word="صحيح",
yes="نعم",
no="لا",
@@ -103,6 +118,13 @@ def __getattribute__(self, name: str) -> str:
sentence_space=" ",
colon=":",
indices=["أ", "ب", "ج", "د", "هـ", "و", "ز", "ح"],
+ # Translated using gpt-4o
+ continuation_mcf_instruction="اختر الحرف الذي يمثل الاستمرار الأكثر احتمالاً",
+ nli_mcf_instruction="اختر حرف العلاقة الأكثر احتمالا بين المقدمة والفرضية",
+ qa_instruction="أجب عن السؤال التالي",
+ multichoice_mcf_instruction="اختر الحرف الذي يمثل الإجابة الصحيحة",
+ default_formatting_instruction="إخراج الإجابة النهائية بالتنسيق: .",
+ math_formatting_instruction="اكتب الإجابة في \\boxed{}",
),
Language.ARMENIAN: TranslationLiterals(language=Language.ARMENIAN),
Language.ASSAMESE: TranslationLiterals(language=Language.ASSAMESE),
@@ -224,6 +246,8 @@ def __getattribute__(self, name: str) -> str:
language=Language.CHINESE,
question_word="问题",
answer="答案",
+ answer_cot="逐步解答",
+ options_word="选项",
confirmation_word="对吗",
yes="是的",
no="不是",
@@ -243,6 +267,13 @@ def __getattribute__(self, name: str) -> str:
sentence_space="",
colon=":",
indices=["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"],
+ # Translated using gpt-4o
+ continuation_mcf_instruction="选择最可能的继续的字母。",
+ nli_mcf_instruction="选择前提和假设之间最可能的关系的字母。",
+ qa_instruction="回答以下问题。",
+ multichoice_mcf_instruction="选择正确答案的字母。",
+ default_formatting_instruction="以格式输出最终答案。",
+ math_formatting_instruction="在 \\boxed{} 环境中输出答案。",
),
Language.CHOKWE: TranslationLiterals(language=Language.CHOKWE),
Language.CRIMEAN_TATAR: TranslationLiterals(language=Language.CRIMEAN_TATAR),
@@ -327,6 +358,7 @@ def __getattribute__(self, name: str) -> str:
language=Language.ENGLISH,
question_word="question",
answer="answer",
+ answer_cot="Step-by-Step Answer",
confirmation_word="right",
yes="yes",
no="no",
@@ -384,7 +416,9 @@ def __getattribute__(self, name: str) -> str:
language=Language.FRENCH,
question_word="question",
answer="réponse",
+ answer_cot="Réponse étape par étape",
confirmation_word="n'est-ce pas",
+ options_word="options",
yes="oui",
no="non",
also="de plus",
@@ -402,6 +436,13 @@ def __getattribute__(self, name: str) -> str:
word_space=" ",
sentence_space=" ",
colon=":",
+ continuation_mcf_instruction="Choisissez la lettre de la continuation la plus probable.",
+ nli_mcf_instruction="Choisissez la lettre de la relation la plus probable entre la prémisse et l’hypothèse.",
+ qa_instruction="Répondez à la question suivante.",
+ multichoice_mcf_instruction="Choisissez la lettre de la réponse correcte.",
+ # Formatting instruction
+ default_formatting_instruction="Affichez la réponse finale au format: .",
+ math_formatting_instruction="Donnez la réponse dans \\boxed{}.",
),
Language.FRIULIAN: TranslationLiterals(language=Language.FRIULIAN),
Language.GALICIAN: TranslationLiterals(
@@ -491,6 +532,8 @@ def __getattribute__(self, name: str) -> str:
language=Language.HINDI,
question_word="सवाल",
answer="उत्तर",
+ answer_cot="चरण दर चरण उत्तर",
+ options_word="विकल्प",
confirmation_word="है ना",
yes="हाँ",
no="नहीं",
@@ -510,6 +553,13 @@ def __getattribute__(self, name: str) -> str:
sentence_space=" ",
colon=":",
indices=["क", "ख", "ग", "घ", "ङ", "च"],
+ continuation_mcf_instruction="अत्यधिक संभावित निरंतरता का अक्षर चुनें।",
+ nli_mcf_instruction="आधार और परिकल्पना के बीच सबसे संभावित संबंध का अक्षर चुनें।",
+ qa_instruction="निम्नलिखित प्रश्न का उत्तर दें।",
+ multichoice_mcf_instruction="सही उत्तर का अक्षर चुनें।",
+ # Formatting instruction
+ default_formatting_instruction="अंतिम उत्तर को इस प्रारूप में आउटपुट करें: ।",
+ math_formatting_instruction="उत्तर को \\boxed{} में आउटपुट करें।",
),
Language.HUNGARIAN: TranslationLiterals(
language=Language.HUNGARIAN,
@@ -760,6 +810,8 @@ def __getattribute__(self, name: str) -> str:
language=Language.RUSSIAN,
question_word="вопрос",
answer="ответ",
+ answer_cot="Пошаговое решение",
+ options_word="варианты",
confirmation_word="верно",
yes="да",
no="нет",
@@ -778,6 +830,13 @@ def __getattribute__(self, name: str) -> str:
sentence_space=" ",
colon=":",
indices=["А", "Б", "В", "Г", "Д", "Е"],
+ continuation_mcf_instruction="Выберите букву наиболее вероятного продолжения.",
+ nli_mcf_instruction="Выберите букву наиболее вероятной связи между предпосылкой и гипотезой.",
+ qa_instruction="Ответьте на следующий вопрос.",
+ multichoice_mcf_instruction="Выберите букву правильного ответа.",
+ # Formatting instruction
+ default_formatting_instruction="Выведите окончательный ответ в формате: .",
+ math_formatting_instruction="Выведите ответ в \\boxed{}.",
),
Language.SAMOAN: TranslationLiterals(language=Language.SAMOAN),
Language.SANGO: TranslationLiterals(language=Language.SANGO),
@@ -914,7 +973,9 @@ def __getattribute__(self, name: str) -> str:
language=Language.SWAHILI,
question_word="swali",
answer="jibu",
+ answer_cot="jibu la Hatua kwa Hatua",
confirmation_word="sahihi",
+ options_word="chaguo",
yes="ndiyo",
no="hapana",
also="pia",
@@ -931,6 +992,14 @@ def __getattribute__(self, name: str) -> str:
word_space=" ",
sentence_space=" ",
colon=":",
+ # Translated using gpt-4o
+ continuation_mcf_instruction="Chagua herufi ya mwendelezo unaowezekana zaidi.",
+ nli_mcf_instruction="Chagua herufi ya uhusiano unaowezekana kati ya dhana na dhana.",
+ qa_instruction="Jibu swali lifuatalo.",
+ multichoice_mcf_instruction="Chagua herufi ya jibu sahihi.",
+ # Formatting instruction
+ default_formatting_instruction="Toa jibu la mwisho katika umbizo: .",
+ math_formatting_instruction="Toa jibu katika \\boxed{}.",
),
Language.SWATI: TranslationLiterals(language=Language.SWATI),
Language.SWEDISH: TranslationLiterals(
@@ -993,6 +1062,8 @@ def __getattribute__(self, name: str) -> str:
language=Language.TELUGU,
question_word="ప్రశ్న",
answer="జవాబు",
+ answer_cot="దశలవారీగా సమాధానం",
+ options_word="ఎంపికలు",
confirmation_word="కదా",
yes="అవును",
no="కాదు",
@@ -1010,12 +1081,21 @@ def __getattribute__(self, name: str) -> str:
word_space=" ",
sentence_space=" ",
colon=":",
- indices=["ఎ", "బి", "సి", "డి", "ఇ"],
+ indices=["అ", "ఆ", "ఇ", "ఈ", "ఉ", "ఊ"],
+ continuation_mcf_instruction="అత్యంత సాధ్యమైన కొనసాగింపును సూచించే అక్షరాన్ని ఎంచుకోండి.",
+ nli_mcf_instruction="పూర్వాపరం మరియు పరికల్పన మధ్య అత్యంత సంభావ్య సంబంధం యొక్క అక్షరాన్ని ఎంచుకోండి.",
+ qa_instruction="క్రింది ప్రశ్నకు సమాధానం ఇవ్వండి.",
+ multichoice_mcf_instruction="సరైన సమాధానాన్ని సూచించే అక్షరాన్ని ఎంచుకోండి.",
+ # Formatting instruction
+ default_formatting_instruction="తుది సమాధానాన్ని ఈ ఫార్మాట్లో అవుట్పుట్ చేయండి: .",
+ math_formatting_instruction="సమాధానాన్ని \\boxed{} లో ఇవ్వండి.",
),
Language.THAI: TranslationLiterals(
language=Language.THAI,
question_word="คำถาม",
answer="คำตอบ",
+ answer_cot="คำตอบทีละขั้นตอน",
+ options_word="ตัวเลือก",
confirmation_word="ใช่ไหม",
yes="ใช่",
no="ไม่",
@@ -1027,13 +1107,19 @@ def __getattribute__(self, name: str) -> str:
neither="ไม่ใช่ทั้งสองอย่าง",
or_word="หรือ",
full_stop=".",
- comma=",",
question_mark="?",
exclamation_mark="!",
word_space="",
sentence_space=" ",
colon=":",
- indices=["๑", "๒", "๓", "๔", "๕", "๖", "๗", "๘", "๙", "๐"],
+ indices=["ก", "ข", "ค", "ง", "จ", "ฉ", "ช", "ซ"],
+ continuation_mcf_instruction="เลือกตัวอักษรของการดำเนินการต่อที่มีความเป็นไปได้มากที่สุด",
+ nli_mcf_instruction="เลือกตัวอักษรที่มีความสัมพันธ์ที่เป็นไปได้มากที่สุดระหว่างข้อตั้งและสมมติฐาน",
+ qa_instruction="ตอบคำถามต่อไปนี้",
+ multichoice_mcf_instruction="เลือกตัวอักษรของคำตอบที่ถูกต้อง",
+ # Formatting instruction
+ default_formatting_instruction="ส่งออกคำตอบสุดท้ายในรูปแบบ: ",
+ math_formatting_instruction="แสดงคำตอบใน \\boxed{}",
),
Language.TIGRINYA: TranslationLiterals(language=Language.TIGRINYA),
Language.TOK_PISIN: TranslationLiterals(language=Language.TOK_PISIN),
@@ -1046,7 +1132,9 @@ def __getattribute__(self, name: str) -> str:
language=Language.TURKISH,
question_word="soru",
answer="cevap",
+ answer_cot="adım adım cevap",
confirmation_word="değil mi",
+ options_word="seçenekler",
yes="evet",
no="hayır",
also="ayrıca",
@@ -1064,6 +1152,13 @@ def __getattribute__(self, name: str) -> str:
word_space=" ",
sentence_space=" ",
colon=":",
+ continuation_mcf_instruction="En olası devamı temsil eden harfi seçin.",
+ nli_mcf_instruction="Öncül ile hipotez arasındaki olası ilişkinin harfini seçiniz.",
+ qa_instruction="Aşağıdaki soruyu yanıtlayın.",
+ multichoice_mcf_instruction="Doğru cevabı temsil eden harfi seçin.",
+ # Formatting instruction
+ default_formatting_instruction="Son cevabı şu formatta çıktı olarak verin: .",
+ math_formatting_instruction="Cevabı \\boxed{} içinde verin.",
),
Language.TURKMEN: TranslationLiterals(language=Language.TURKMEN),
Language.TWI: TranslationLiterals(language=Language.TWI),
diff --git a/tests/metrics/test_extractive_match.py b/tests/metrics/test_extractive_match.py
index c3a12c813..2b6b5efc6 100644
--- a/tests/metrics/test_extractive_match.py
+++ b/tests/metrics/test_extractive_match.py
@@ -44,7 +44,7 @@ def compare_strings(
gold: str,
pred: str,
language: Language = Language.ENGLISH,
- match_types: list[str] = ["latex", "expr"],
+ match_types: list[str] = ["latex", "expr", "NativeLetters"],
precision: int = 6,
):
"""Helper function to compare strings using the math extraction metrics"""
@@ -56,7 +56,9 @@ def compare_strings(
elif match_type == "expr":
extraction_targets.append(ExprExtractionConfig())
elif match_type == "NativeLetters":
- extraction_targets.append(IndicesExtractionConfig(prefix_for_extraction="NativeLetters"))
+ extraction_targets.append(
+ IndicesExtractionConfig(prefix_for_extraction="NativeLetters", bb_match_priority=0)
+ )
extraction_targets = tuple(extraction_targets) # Convert to tuple
@@ -104,6 +106,7 @@ def test_extraction_abc(gold, pred, expected):
("B", "B。 不是 A", Language.CHINESE, 1),
("B", "B。不是 A", Language.CHINESE, 1),
("B", "B不是 A", Language.CHINESE, 1),
+ ("B", "Hmm I am not sure it's A ?? Not it's B. Surely not D", Language.CHINESE, 1),
],
)
def test_multilingual_extraction_abc(gold, pred, language, expected):
diff --git a/tests/metrics/test_multilingual_metrics.py b/tests/metrics/test_multilingual_metrics.py
new file mode 100644
index 000000000..c7eabe8a4
--- /dev/null
+++ b/tests/metrics/test_multilingual_metrics.py
@@ -0,0 +1,145 @@
+from lighteval.metrics.dynamic_metrics import (
+ multilingual_quasi_exact_match_metric,
+ multilingual_quasi_f1_score_metric,
+)
+from lighteval.utils.language import Language
+
+
+def test_multilingual_quasi_exact_match_happy_path():
+ """Test basic functionality of exact match metric"""
+ metric = multilingual_quasi_exact_match_metric(language=Language.ENGLISH)
+
+ # Test exact match
+ result = metric.sample_level_fn(
+ golds=["hello world"],
+ predictions=["hello world"],
+ )
+ assert result == 1
+
+ # Test with different spacing/punctuation
+ result = metric.sample_level_fn(
+ golds=["hello world"],
+ predictions=["hello, world!"],
+ )
+ assert result == 1
+
+ # Test with no match
+ result = metric.sample_level_fn(
+ golds=["hello world"],
+ predictions=["goodbye world"],
+ )
+ assert result == 0
+
+
+def test_multilingual_quasi_exact_match_bb_extraction():
+ """Test bold text extraction functionality"""
+ metric = multilingual_quasi_exact_match_metric(language=Language.ENGLISH, extract_bb=True)
+
+ # Test with single bold tag
+ result = metric.sample_level_fn(
+ golds=["answer"],
+ predictions=["The correct answer is answer"],
+ )
+ assert result == 1
+
+ # Test with multiple bold tags - should take last one
+ result = metric.sample_level_fn(
+ golds=["final answer"],
+ predictions=["First wrong then final answer"],
+ )
+ assert result == 1
+
+ # Test with no bold tags - should use full text
+ result = metric.sample_level_fn(
+ golds=["answer"],
+ predictions=["answer"],
+ )
+ assert result == 1
+
+ # Test with empty bold tags
+ result = metric.sample_level_fn(
+ golds=["answer"],
+ predictions=[" answer"],
+ )
+ assert result == 0
+
+
+def test_multilingual_quasi_f1_score_happy_path():
+ """Test basic functionality of F1 score metric"""
+ metric = multilingual_quasi_f1_score_metric(language=Language.ENGLISH)
+
+ # Test perfect match
+ result = metric.sample_level_fn(
+ golds=["hello world"],
+ predictions=["hello world"],
+ )
+ assert result == 1
+
+ # Test partial match
+ result = metric.sample_level_fn(
+ golds=["hello beautiful world"],
+ predictions=["hello world"],
+ )
+ assert result > 0 and result < 1
+
+ # Test no match
+ result = metric.sample_level_fn(
+ golds=["hello world"],
+ predictions=["goodbye moon"],
+ )
+ assert result == 0
+
+
+def test_multilingual_quasi_f1_score_bb_extraction():
+ """Test bold text extraction functionality with F1 score"""
+ metric = multilingual_quasi_f1_score_metric(language=Language.ENGLISH, extract_bb=True)
+
+ # Test with single bold tag
+ result = metric.sample_level_fn(
+ golds=["answer key"],
+ predictions=["The correct answer is answer key"],
+ )
+ assert result == 1
+
+ # Test with multiple bold tags - should take last one
+ result = metric.sample_level_fn(
+ golds=["final answer"],
+ predictions=["First wrong then final answer"],
+ )
+ assert result == 1
+
+ # Test with partial match in bold
+ result = metric.sample_level_fn(
+ golds=["complete answer key"],
+ predictions=["The text contains answer key"],
+ )
+ assert result > 0 and result < 1
+
+ # Test with no bold tags - should use full text
+ result = metric.sample_level_fn(
+ golds=["answer"],
+ predictions=["answer"],
+ )
+ assert result == 1
+
+
+def test_multilingual_support():
+ """Test metrics work with different languages"""
+ languages = [Language.ENGLISH, Language.FRENCH, Language.CHINESE]
+
+ for lang in languages:
+ # Test exact match
+ em_metric = multilingual_quasi_exact_match_metric(language=lang)
+ result = em_metric.sample_level_fn(
+ golds=["test"],
+ predictions=["test"],
+ )
+ assert result == 1
+
+ # Test F1 score
+ f1_metric = multilingual_quasi_f1_score_metric(language=lang)
+ result = f1_metric.sample_level_fn(
+ golds=["test"],
+ predictions=["test"],
+ )
+ assert result == 1
diff --git a/tests/tasks/templates/test_continuation.py b/tests/tasks/templates/test_continuation.py
index 9681ecd66..f09a3e1d3 100644
--- a/tests/tasks/templates/test_continuation.py
+++ b/tests/tasks/templates/test_continuation.py
@@ -46,6 +46,8 @@ def test_continuation_prompt_mcf():
doc.query
== """\
The quick brown fox
+
+Options:
A. jumps over the lazy dog
B. runs through the forest
C. chases a rabbit
@@ -131,7 +133,10 @@ def test_continuation_optional_keys():
doc.query
== """\
Choose the most likely continuation:
+
In the morning, I like to
+
+Options:
A. drink coffee
B. go for a run
C. read the news
@@ -142,3 +147,88 @@ def test_continuation_optional_keys():
assert doc.unconditioned_query == "Answer:"
assert doc.choices == [" A", " B", " C"]
assert doc.gold_index == [0]
+
+
+def test_continuation_prompt_mcf_cot():
+ """Test multiple-choice format continuation prompt generation with cot."""
+ test_input = {
+ "context": "The quick brown fox",
+ "continuations": ["jumps over the lazy dog", "Runs through the forest", "Chases a rabbit"],
+ "gold_idx": 0,
+ "__few_shots": True,
+ "few_shot_cot": "i think it's A. jumps over the lazy dog",
+ "instruction": "Choose the letter of the most likely continuation.",
+ }
+
+ prompt_fn = get_continuation_prompt_function(
+ Language.ENGLISH,
+ {
+ "context": "context",
+ "continuations": "continuations",
+ "gold_idx": "gold_idx",
+ "few_shot_cot": "few_shot_cot",
+ "instruction": "instruction",
+ },
+ MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_continuation_task")
+
+ # We expect the continuation to be decapitalized as it's a continuation of an unfinished sentence
+ assert (
+ doc.query
+ == """\
+Choose the letter of the most likely continuation.
+
+The quick brown fox
+
+Options:
+ A. jumps over the lazy dog
+ B. runs through the forest
+ C. chases a rabbit
+Step-by-Step Answer:\
+"""
+ )
+
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" I think it's A. jumps over the lazy dog"]
+ assert doc.gold_index == [0]
+
+
+def test_continuation_default_instruction_mcf():
+ """Test default instruction for MCF continuation prompt."""
+ test_input = {
+ "context": "The quick brown fox",
+ "continuations": ["jumps over the lazy dog", "Runs through the forest", "Chases a rabbit"],
+ "gold_idx": 0,
+ }
+
+ # Note: "instruction" key is NOT in the key_map
+ prompt_fn = get_continuation_prompt_function(
+ Language.ENGLISH,
+ {"context": "context", "continuations": "continuations", "gold_idx": "gold_idx"},
+ MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_continuation_task_default_mcf")
+
+ expected_instruction = (
+ "Choose the letter of the most likely continuation.\nOutput the final answer in format: ."
+ )
+ assert (
+ doc.query
+ == f"""\
+{expected_instruction}
+
+The quick brown fox
+
+Options:
+ A. jumps over the lazy dog
+ B. runs through the forest
+ C. chases a rabbit
+Step-by-Step Answer:\
+"""
+ )
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" A", " B", " C"]
+ assert doc.gold_index == [0]
diff --git a/tests/tasks/templates/test_copa.py b/tests/tasks/templates/test_copa.py
index 775fb37fb..9f7f40ffe 100644
--- a/tests/tasks/templates/test_copa.py
+++ b/tests/tasks/templates/test_copa.py
@@ -23,7 +23,7 @@
import pytest
from lighteval.tasks.templates.copa import get_copa_prompt_function
-from lighteval.tasks.templates.utils.formulation import CFFormulation
+from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation
from lighteval.utils.language import Language
@@ -60,3 +60,55 @@ def test_copa_prompt_cf(cause_effect):
assert doc.unconditioned_query == ""
assert doc.choices == [" he has big muscles", " he is weak"]
assert doc.gold_index == [0]
+
+
+@pytest.mark.parametrize("cause_effect", ["cause", "effect"])
+def test_copa_prompt_mcf_cot(cause_effect):
+ """
+ Tests that copa prompt function works correctly for both cause/effect.
+ Since it's pretty much a wrapper around the continuation template, we just test a single formulation.
+ """
+ test_input = {
+ "cause_effect": cause_effect,
+ "context": "He is strong",
+ "continuations": ["he has big muscles", "he is weak"],
+ "gold_idx": 0,
+ "__few_shots": True,
+ "few_shot_cot": "i think it's A. he has big muscles",
+ "instruction": "Choose the letter of the most likely continuation.",
+ }
+
+ prompt_fn = get_copa_prompt_function(
+ Language.ENGLISH,
+ {
+ "cause_effect": "cause_effect",
+ "context": "context",
+ "continuations": "continuations",
+ "gold_idx": "gold_idx",
+ "few_shot_cot": "few_shot_cot",
+ "instruction": "instruction",
+ },
+ MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+
+ cause_effect_word = "because" if cause_effect == "cause" else "therefore"
+ assert (
+ doc.query
+ == f"""\
+Choose the letter of the most likely continuation.
+
+He is strong {cause_effect_word}
+
+Options:
+ A. he has big muscles
+ B. he is weak
+Step-by-Step Answer:\
+"""
+ )
+
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" I think it's A. he has big muscles"]
+ assert doc.gold_index == [0]
diff --git a/tests/tasks/templates/test_hellaswag.py b/tests/tasks/templates/test_hellaswag.py
index 2ef7b895b..39b8cb78c 100644
--- a/tests/tasks/templates/test_hellaswag.py
+++ b/tests/tasks/templates/test_hellaswag.py
@@ -91,6 +91,8 @@ def test_hellaswag_prompt_mcf():
doc.query
== """\
Fitness:\nHe is strong he is fast
+
+Options:
A. he has big muscles
B. he is weak
Answer:\
@@ -158,3 +160,102 @@ def test_hellaswag_single_ctx():
doc = prompt_fn(test_input, "test_task")
assert doc.query == "Fitness:\nHe is strong."
+
+
+def test_hellaswag_prompt_mcf_cot():
+ """
+ Tests that hellaswag prompt function works correctly.
+ Since it's pretty much a wrapper around the continuation template, we just test a single formulation.
+ """
+ test_input = {
+ "activity_label": "fitness",
+ "ctx_a": "He is strong",
+ "ctx_b": "He is fast",
+ "continuations": ["he has big muscles", "he is weak"],
+ "gold_idx": 0,
+ "__few_shots": True,
+ "few_shot_cot": "i think it's A. he has big muscles",
+ "instruction": "Choose the letter of the most likely continuation.",
+ }
+
+ prompt_fn = get_hellaswag_prompt_function(
+ Language.ENGLISH,
+ {
+ "activity_label": "activity_label",
+ "continuations": "continuations",
+ "gold_idx": "gold_idx",
+ "ctx_a": "ctx_a",
+ "ctx_b": "ctx_b",
+ "few_shot_cot": "few_shot_cot",
+ "instruction": "instruction",
+ },
+ MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+ assert (
+ doc.query
+ == """\
+Choose the letter of the most likely continuation.
+
+Fitness:
+He is strong he is fast
+
+Options:
+ A. he has big muscles
+ B. he is weak
+Step-by-Step Answer:\
+"""
+ )
+
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" I think it's A. he has big muscles"]
+ assert doc.gold_index == [0]
+
+
+def test_hellaswag_default_instruction_mcf():
+ """Test default instruction for MCF hellaswag prompt."""
+ test_input = {
+ "activity_label": "fitness",
+ "ctx_a": "He is strong",
+ "ctx_b": "He is fast",
+ "continuations": ["he has big muscles", "he is weak"],
+ "gold_idx": 0,
+ }
+
+ # Note: "instruction" key is NOT in the key_map
+ prompt_fn = get_hellaswag_prompt_function(
+ Language.ENGLISH,
+ {
+ "activity_label": "activity_label",
+ "continuations": "continuations",
+ "gold_idx": "gold_idx",
+ "ctx_a": "ctx_a",
+ "ctx_b": "ctx_b",
+ },
+ MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_task_default_mcf")
+
+ expected_instruction = (
+ "Choose the letter of the most likely continuation.\nOutput the final answer in format: ."
+ )
+ assert (
+ doc.query
+ == f"""\
+{expected_instruction}
+
+Fitness:
+He is strong he is fast
+
+Options:
+ A. he has big muscles
+ B. he is weak
+Step-by-Step Answer:\
+"""
+ )
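+ # Without a few-shot CoT answer, the MCF choices remain the letter options.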
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" A", " B"]
+ assert doc.gold_index == [0]
diff --git a/tests/tasks/templates/test_math_qa.py b/tests/tasks/templates/test_math_qa.py
new file mode 100644
index 000000000..ef0a48766
--- /dev/null
+++ b/tests/tasks/templates/test_math_qa.py
@@ -0,0 +1,66 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+from lighteval.tasks.templates.math_qa import get_math_qa_prompt_function
+from lighteval.tasks.templates.qa import QAInput
+from lighteval.utils.language import Language
+
+
+def test_math_qa_prompt_cf_cot_default_instruction():
+ """
+ Tests Math QA with CoT and the default instruction.
+ """
+ test_input = {
+ "question": "Solve for x: x + 5 = 10",
+ "choices": ["5"],
+ }
+
+ prompt_fn = get_math_qa_prompt_function(
+ language=Language.ENGLISH,
+ adapter=lambda x: QAInput(
+ question=x["question"],
+ choices=x["choices"],
+ ),
+ cot=True,
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+
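+ # The default math instruction asks for the final answer in \boxed{} (see the query below).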
+ assert (
+ doc.query
+ == """\
+Answer the following question.
+Output the answer in \\boxed{}.
+
+Question: Solve for x: x + 5 = 10
+Step-by-Step Answer:\
+"""
+ )
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" 5"]
+ assert doc.gold_index == [0]
diff --git a/tests/tasks/templates/test_multichoice.py b/tests/tasks/templates/test_multichoice.py
index 9b3614ed2..b0d4cd577 100644
--- a/tests/tasks/templates/test_multichoice.py
+++ b/tests/tasks/templates/test_multichoice.py
@@ -177,6 +177,7 @@ def test_multichoice_optional_keys():
doc.query
== """\
Please answer the following question about geography.
+
France is big.
Question: What is the capital of France?
A. London
@@ -186,3 +187,123 @@ def test_multichoice_optional_keys():
Answer:\
"""
)
+
+
+def test_multichoice_prompt_mcf_cot():
+ """Test multiple-choice format (MCF) with COT prompt generation for multichoice questions."""
+ test_input = {
+ "question": "What is the capital of France?",
+ "choices": ["London", "Paris", "Berlin", "Madrid"],
+ "gold_idx": 1,
+ "instruction": "Please answer the following question about geography.",
+ "__few_shots": True,
+ "few_shot_cot": "i think it's D",
+ }
+
+ prompt_fn = get_mcq_prompt_function(
+ Language.ENGLISH,
+ {
+ "question": "question",
+ "choices": "choices",
+ "gold_idx": "gold_idx",
+ "few_shot_cot": "few_shot_cot",
+ "instruction": "instruction",
+ },
+ MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+
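+ # The lowercase few-shot CoT answer is capitalized and used as the single target choice.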
+ assert (
+ doc.query
+ == """\
+Please answer the following question about geography.
+
+Question: What is the capital of France?
+ A. London
+ B. Paris
+ C. Berlin
+ D. Madrid
+Step-by-Step Answer:\
+"""
+ )
+
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" I think it's D"]
+
+
+def test_multichoice_default_instruction_mcf():
+ """Test default instruction for MCF multichoice prompt."""
+ test_input = {
+ "question": "What is the capital of France?",
+ "choices": ["London", "Paris", "Berlin", "Madrid"],
+ "gold_idx": 1,
+ }
+
+ # Note: "instruction" key is NOT in the key_map
+ prompt_fn = get_mcq_prompt_function(
+ Language.ENGLISH,
+ {
+ "question": "question",
+ "choices": "choices",
+ "gold_idx": "gold_idx",
+ },
+ MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_task_default_mcf")
+
+ # Default instruction from TranslationLiterals.ENGLISH.multichoice_instruction
+ expected_instruction = "Choose the letter of the correct answer.\nOutput the final answer in format: ."
+ expected_query = f"""\
+{expected_instruction}
+
+Question: What is the capital of France?
+ A. London
+ B. Paris
+ C. Berlin
+ D. Madrid
+Step-by-Step Answer:\
+"""
+ assert doc.query == expected_query
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" A", " B", " C", " D"]
+ assert doc.gold_index == [1]
+
+
+def test_multichoice_default_instruction_cf():
+ """Test default instruction for CF multichoice prompt."""
+ test_input = {
+ "question": "What is the capital of France?",
+ "choices": ["London", "Paris", "Berlin", "Madrid"],
+ "gold_idx": 1,
+ }
+
+ # Note: "instruction" key is NOT in the key_map
+ prompt_fn = get_mcq_prompt_function(
+ Language.ENGLISH,
+ {
+ "question": "question",
+ "choices": "choices",
+ "gold_idx": "gold_idx",
+ },
+ CFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_task_default_cf")
+
+ # Default instruction from TranslationLiterals.ENGLISH.multichoice_instruction
+ expected_instruction = "Answer the following question.\nOutput the final answer in format: ."
+ assert (
+ doc.query
+ == f"""\
+{expected_instruction}
+
+Question: What is the capital of France?
+Step-by-Step Answer:\
+"""
+ )
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
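+ # Under CF the choices stay as the full answer texts rather than letters.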
+ assert doc.choices == [" London", " Paris", " Berlin", " Madrid"]
+ assert doc.gold_index == [1]
diff --git a/tests/tasks/templates/test_nli.py b/tests/tasks/templates/test_nli.py
index a634432b9..193f9cce6 100644
--- a/tests/tasks/templates/test_nli.py
+++ b/tests/tasks/templates/test_nli.py
@@ -21,7 +21,7 @@
# SOFTWARE.
from lighteval.tasks.templates.nli import get_nli_prompt_function
-from lighteval.tasks.templates.utils.formulation import CFFormulation, HybridFormulation
+from lighteval.tasks.templates.utils.formulation import CFFormulation, HybridFormulation, MCFFormulation
from lighteval.utils.language import Language
@@ -114,3 +114,85 @@ def test_nli_prompt_hybrid():
assert doc.unconditioned_query == "Answer:"
assert doc.choices == [" True", " Neither", " False"]
assert doc.gold_index == [2]
+
+
+def test_nli_prompt_mcf_cot():
+ """Test multiple-choice format NLI prompt generation with cot."""
+ test_input = {
+ "premise": "The cat is sleeping on the couch.",
+ "hypothesis": "The cat is awake.",
+ "gold_idx": 2,
+ "__few_shots": True,
+ "few_shot_cot": "i think it's A. True",
+ "instruction": "Choose the correct relation between the premise and hypothesis",
+ }
+
+ prompt_fn = get_nli_prompt_function(
+ Language.ENGLISH,
+ {
+ "hypothesis": "hypothesis",
+ "premise": "premise",
+ "gold_idx": "gold_idx",
+ "few_shot_cot": "few_shot_cot",
+ "instruction": "instruction",
+ },
+ ["entailment", "neutral", "contradiction"],
+ formulation=MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_nli_task")
+
+ assert (
+ doc.query
+ == """\
+Choose the correct relation between the premise and hypothesis
+
+The cat is sleeping on the couch.
+Question: The cat is awake.
+ A. True
+ B. Neither
+ C. False
+Step-by-Step Answer:\
+"""
+ )
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" I think it's A. True"]
+ assert doc.gold_index == [0]
+
+
+def test_nli_prompt_mcf_cot_default_instruction():
+ """Test multiple-choice format NLI prompt generation with cot."""
+ test_input = {
+ "premise": "The cat is sleeping on the couch.",
+ "hypothesis": "The cat is awake.",
+ "gold_idx": 2,
+ "__few_shots": True,
+ "few_shot_cot": "i think it's A. True",
+ }
+
+ prompt_fn = get_nli_prompt_function(
+ Language.ENGLISH,
+ {"hypothesis": "hypothesis", "premise": "premise", "gold_idx": "gold_idx", "few_shot_cot": "few_shot_cot"},
+ ["entailment", "neutral", "contradiction"],
+ formulation=MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_nli_task")
+
+ assert (
+ doc.query
+ == """\
+Choose the letter of the most likely relation between the premise and hypothesis.
+Output the final answer in format: .
+
+The cat is sleeping on the couch.
+Question: The cat is awake.
+ A. True
+ B. Neither
+ C. False
+Step-by-Step Answer:\
+"""
+ )
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" I think it's A. True"]
+ assert doc.gold_index == [0]
diff --git a/tests/tasks/templates/test_qa.py b/tests/tasks/templates/test_qa.py
new file mode 100644
index 000000000..7c9a4dc52
--- /dev/null
+++ b/tests/tasks/templates/test_qa.py
@@ -0,0 +1,135 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+from lighteval.tasks.templates.qa import (
+ QAInput,
+ get_qa_prompt_function,
+)
+from lighteval.utils.language import Language
+
+
+def test_qa_prompt_cf():
+ """
+ Tests that the QA prompt function works correctly for the CF formulation.
+ """
+ test_input = {
+ "question": "What is 2 + 2?",
+ "choices": ["4", "5", "6"],
+ }
+
+ prompt_fn = get_qa_prompt_function(
+ language=Language.ENGLISH,
+ adapter=lambda x: QAInput(
+ question=x["question"],
+ choices=x["choices"],
+ ),
+ cot=False,
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+ assert doc is not None
+
+ assert (
+ doc.query
+ == """\
+Question: What is 2 + 2?
+Answer:\
+"""
+ )
+ assert doc.unconditioned_query == "Answer:"
+ assert doc.choices == [" 4", " 5", " 6"]
+ assert doc.gold_index == [0, 1, 2]
+
+
+def test_qa_prompt_cf_cot_default_instruction():
+ """
+ Tests QA with CoT and the default instruction.
+ """
+ test_input = {
+ "question": "Solve for x: x + 5 = 10",
+ "choices": ["5", "10", "0"],
+ }
+
+ prompt_fn = get_qa_prompt_function(
+ language=Language.ENGLISH,
+ adapter=lambda x: QAInput(
+ question=x["question"],
+ choices=x["choices"],
+ ),
+ cot=True,
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+
+ assert (
+ doc.query
+ == """\
+Answer the following question.
+Output the final answer in format: .
+
+Question: Solve for x: x + 5 = 10
+Step-by-Step Answer:\
+"""
+ )
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" 5", " 10", " 0"]
+ assert doc.gold_index == [0, 1, 2]
+
+
+def test_qa_prompt_cf_cot_custom_instruction():
+ """
+ Tests QA with CoT and a custom instruction.
+ """
+ test_input = {
+ "question": "Solve for x: x + 5 = 10",
+ "choices": ["5", "10", "0"],
+ "instruction": "Answer the following question. Output the final answer in format: .",
+ "few_shot_cot": "1+2+3+4+5=15",
+ "__few_shots": True,
+ }
+
+ prompt_fn = get_qa_prompt_function(
+ language=Language.ENGLISH,
+ adapter=lambda x: QAInput(
+ question=x["question"],
+ choices=x["choices"],
+ instruction=x["instruction"],
+ few_shot_cot=x["few_shot_cot"],
+ ),
+ cot=True,
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+
+ assert (
+ doc.query
+ == """\
+Answer the following question. Output the final answer in format: .
+
+Question: Solve for x: x + 5 = 10
+Step-by-Step Answer:\
+"""
+ )
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" 1+2+3+4+5=15"]
+ assert doc.gold_index == [0]
diff --git a/tests/tasks/templates/test_translation.py b/tests/tasks/templates/test_translation.py
index eab59cf18..1089de8ac 100644
--- a/tests/tasks/templates/test_translation.py
+++ b/tests/tasks/templates/test_translation.py
@@ -21,6 +21,8 @@
# SOFTWARE.
+import pytest
+
from lighteval.tasks.templates.translation import get_translation_prompt_function
from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation
from lighteval.utils.language import Language
@@ -81,6 +83,8 @@ def test_translation_prompt_mcf():
doc.query
== """\
CS: Ahoj, jak se máš? FR:
+
+Options:
A. Bonjour, comment allez-vous?
B. Ciao, come stai?
Answer:\
@@ -118,3 +122,171 @@ def test_translation_prompt_cf_formatting():
assert doc.unconditioned_query == ""
assert doc.choices == [" 你好吗?"]
assert doc.gold_index == [0]
+
+
+def test_translation_cot_default_instruction():
+ """
+ Tests that the translation prompt function uses the default instruction when CoT is enabled.
+ """
+ test_input = {
+ "source_text": "How are you?",
+ "target_text": "你好吗?",
+ }
+
+ prompt_fn = get_translation_prompt_function(
+ source_language=Language.ENGLISH,
+ target_language=Language.CHINESE,
+ adapter=lambda x: {
+ "source_text": x["source_text"],
+ "target_text": x["target_text"],
+ },
+ formulation=CFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+ assert doc is not None
+
+ # Check that the default instruction is included
+ expected_instruction = "Translate the following text from English to Chinese.\n"
+ assert doc.query.startswith(expected_instruction)
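+ # Source and target segments are prefixed with uppercase language codes (EN/ZH here).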
+ assert "EN: How are you? ZH:" in doc.query
+ assert doc.choices == [" 你好吗?"]
+ assert doc.gold_index == [0]
+
+
+def test_translation_cot_default_instruction_mcf():
+ """
+ Tests that the translation prompt function uses the default instructions when CoT is enabled for the MCF formulation.
+ """
+ test_input = {
+ "source_text": "Ahoj, jak se máš?",
+ "target_text": ["Bonjour, comment allez-vous?", "Ciao, come stai?"],
+ }
+
+ prompt_fn = get_translation_prompt_function(
+ source_language=Language.CZECH,
+ target_language=Language.FRENCH,
+ adapter=lambda x: {
+ "source_text": x["source_text"],
+ "target_text": x["target_text"],
+ "gold_idx": 0,
+ },
+ formulation=MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+ assert doc is not None
+
+ # Check that both default instructions are included
+ expected_instructions = "Choose the letter of the correct answer.\nOutput the final answer in format: .\n\n"
+ assert doc.query.startswith(expected_instructions)
+ assert "CS: Ahoj, jak se máš? FR:" in doc.query
+ assert "A. Bonjour, comment allez-vous?" in doc.query
+ assert "B. Ciao, come stai?" in doc.query
+ assert doc.choices == [" A", " B"]
+ assert doc.gold_index == [0]
+
+
+def test_translation_cot_user_instruction():
+ """
+ Tests that the translation prompt function uses the user-provided instruction when available.
+ """
+ test_input = {
+ "source_text": "How are you?",
+ "target_text": "你好吗?",
+ "instruction": "Please translate this English text to Chinese:",
+ }
+
+ prompt_fn = get_translation_prompt_function(
+ source_language=Language.ENGLISH,
+ target_language=Language.CHINESE,
+ adapter=lambda x: {
+ "source_text": x["source_text"],
+ "target_text": x["target_text"],
+ "instruction": x["instruction"],
+ },
+ formulation=CFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+ assert doc is not None
+
+ # Check that the user-provided instruction is used as-is, with no default formatting instruction appended
+ expected_instructions = "Please translate this English text to Chinese:\n\n"
+ assert doc.query.startswith(expected_instructions)
+ assert "EN: How are you? ZH:" in doc.query
+ assert doc.choices == [" 你好吗?"]
+ assert doc.gold_index == [0]
+
+
+def test_translation_cot_mcf_number_prefix_error():
+ """
+ Tests that the translation prompt function raises an error when using CoT with MCF and a "Numbers" choice prefix.
+ """
+ test_input = {
+ "source_text": "How are you?",
+ "target_text": "你好吗?",
+ }
+
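+ # CoT with MCF only supports letter choice prefixes, so a "Numbers" prefix is rejected at construction time.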
+ with pytest.raises(ValueError, match="You are using a COT with a unsupported formulation"):
+ prompt_fn = get_translation_prompt_function(
+ source_language=Language.ENGLISH,
+ target_language=Language.CHINESE,
+ adapter=lambda x: {
+ "source_text": x["source_text"],
+ "target_text": x["target_text"],
+ "gold_idx": 0,
+ },
+ formulation=MCFFormulation(cot=True, choice_prefix="Numbers"),
+ )
+
+ prompt_fn(test_input, "test_task")
+
+
+def test_translation_prompt_mcf_cot():
+ """
+ Tests that translation prompt function works correctly for both cause/effect.
+ Since it's pretty much a wrapper around continuation template we just test single formulation.
+
+ """
+ test_input = {
+ "source_text": "How are you?",
+ "target_text": ["你好吗?", "你怎么样?"],
+ "__few_shots": True,
+ "few_shot_cot": "i think it's A.",
+ "gold_idx": 0,
+ "instruction": "Choose the letter of the most likely continuation.",
+ }
+
+ prompt_fn = get_translation_prompt_function(
+ Language.ENGLISH,
+ Language.CHINESE,
+ {
+ "source_text": "source_text",
+ "target_text": "target_text",
+ "gold_idx": "gold_idx",
+ "few_shot_cot": "few_shot_cot",
+ "instruction": "instruction",
+ },
+ MCFFormulation(cot=True),
+ )
+
+ doc = prompt_fn(test_input, "test_task")
+
+ assert (
+ doc.query
+ == """\
+Choose the letter of the most likely continuation.
+
+EN: How are you? ZH:
+
+Options:
+ A. 你好吗?
+ B. 你怎么样?
+Step-by-Step Answer:\
+"""
+ )
+
+ assert doc.unconditioned_query == "Step-by-Step Answer:"
+ assert doc.choices == [" I think it's A."]
+ assert doc.gold_index == [0]