Skip to content

Commit 956bc97

Browse files
committed
Merge remote-tracking branch 'origin/main' into HEAD
2 parents 75383d3 + 441d7a4 commit 956bc97

File tree

6 files changed

+238
-7
lines changed

6 files changed

+238
-7
lines changed

src/lighteval/logging/evaluation_tracker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,12 +325,13 @@ def push_to_hub(
325325
# We upload it both as a json and a parquet file
326326
result_file_base_name = f"results_{date_id}"
327327
results_json = json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False)
328-
self.api.upload_file(
328+
url = self.api.upload_file(
329329
repo_id=repo_id,
330330
path_or_fileobj=BytesIO(results_json.encode("utf-8")),
331331
path_in_repo=f"{result_file_base_name}.json",
332332
repo_type="dataset",
333333
)
334+
logger.info(f"Uploaded evaluation results to {url}")
334335

335336
results_dataset = Dataset.from_dict(
336337
{key: [json.dumps(v, cls=EnhancedJSONEncoder, indent=2)] for key, v in results_dict.items()}

src/lighteval/metrics/metrics.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
import numpy as np
2525
from aenum import Enum
2626

27+
from lighteval.metrics.dynamic_metrics import (
28+
IndicesExtractionConfig,
29+
multilingual_extractive_match_metric,
30+
)
2731
from lighteval.metrics.harness_compatibility.drop import drop_metrics
2832
from lighteval.metrics.harness_compatibility.truthful_qa import truthfulqa_mc_metrics
2933
from lighteval.metrics.metrics_corpus import (
@@ -44,6 +48,7 @@
4448
Faithfulness,
4549
LoglikelihoodAcc,
4650
MajAtK,
51+
PassAtK,
4752
Recall,
4853
StringDistance,
4954
acc_golds_likelihood,
@@ -69,6 +74,7 @@
6974
SampleLevelMetric,
7075
SampleLevelMetricGrouping,
7176
)
77+
from lighteval.utils.language import Language
7278
from lighteval.utils.utils import as_list
7379

7480

@@ -364,6 +370,30 @@ class Metrics(Enum):
364370
corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute,
365371
higher_is_better=True,
366372
)
373+
pass_at_1 = SampleLevelMetric(
374+
metric_name="pass@1:32_samples",
375+
sample_level_fn=PassAtK(k=1, n=32, strip_strings=True).compute,
376+
category=MetricCategory.GENERATIVE_SAMPLING,
377+
use_case=MetricUseCase.REASONING,
378+
corpus_level_fn=np.mean,
379+
higher_is_better=True,
380+
)
381+
pass_at_10 = SampleLevelMetric(
382+
metric_name="pass@10:32_samples",
383+
sample_level_fn=PassAtK(k=10, n=32, strip_strings=True).compute,
384+
category=MetricCategory.GENERATIVE_SAMPLING,
385+
use_case=MetricUseCase.REASONING,
386+
corpus_level_fn=np.mean,
387+
higher_is_better=True,
388+
)
389+
pass_at_100 = SampleLevelMetric(
390+
metric_name="pass@100:32_samples",
391+
sample_level_fn=PassAtK(k=100, n=32, strip_strings=True).compute,
392+
category=MetricCategory.GENERATIVE_SAMPLING,
393+
use_case=MetricUseCase.REASONING,
394+
corpus_level_fn=np.mean,
395+
higher_is_better=True,
396+
)
367397
perfect_exact_match = SampleLevelMetric(
368398
metric_name="perfect_em",
369399
sample_level_fn=ExactMatches().compute,
@@ -549,6 +579,12 @@ class Metrics(Enum):
549579
corpus_level_fn=CorpusLevelPerplexityMetric("weighted_perplexity").compute,
550580
higher_is_better=False,
551581
)
582+
gpqa_instruct_metric = multilingual_extractive_match_metric(
583+
language=Language.ENGLISH,
584+
gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
585+
pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
586+
precision=6,
587+
)
552588

553589
def __str__(self):
554590
return self.name.replace("_at_", "@")

src/lighteval/metrics/metrics_sample.py

Lines changed: 131 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
import logging
2828
import os
29-
from typing import Callable, Literal
29+
from typing import Callable, Literal, Union
3030

3131
import nltk
3232
import numpy as np
@@ -708,9 +708,21 @@ def __init__(self):
708708
"""Creates a BLEURT scorer using a light bleurt-tiny-512 model.
709709
For more complex use cases, could also be Elron/bleurt-base-128
710710
"""
711-
self.tokenizer = AutoTokenizer.from_pretrained("Elron/bleurt-tiny-512")
712-
self.model = AutoModelForSequenceClassification.from_pretrained("Elron/bleurt-tiny-512")
713-
self.model.eval()
711+
self._tokenizer = None
712+
self._model = None
713+
714+
@property
715+
def tokenizer(self):
716+
if self._tokenizer is None:
717+
self._tokenizer = AutoTokenizer.from_pretrained("Elron/bleurt-tiny-512")
718+
return self._tokenizer
719+
720+
@property
721+
def model(self):
722+
if self._model is None:
723+
self._model = AutoModelForSequenceClassification.from_pretrained("Elron/bleurt-tiny-512")
724+
self._model.eval()
725+
return self._model
714726

715727
def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
716728
"""Uses the stored BLEURT scorer to compute the score on the current sample.
@@ -1043,3 +1055,118 @@ def compute_score(self, pred: str, gold: str) -> int:
10431055
if self.type_exact_match == "suffix":
10441056
return 1 if pred.endswith(gold) else 0
10451057
return 1 if gold == pred else 0
1058+
1059+
1060+
class PassAtK:
1061+
def __init__(
1062+
self,
1063+
k: int,
1064+
n: int = None,
1065+
normalize_gold: Callable = None,
1066+
normalize_pred: Callable = None,
1067+
strip_strings: bool = False,
1068+
sample_scoring_function: Union[Callable[[str, str], float], str] = None,
1069+
):
1070+
"""Computing pass at k
1071+
1072+
Args:
1073+
k (int): Threshold for the number of successful attempts.
1074+
n (int): Number of samples to generate
1075+
normalize_gold (callable, optional): Function to use to normalize the reference strings.
1076+
Defaults to None if no normalization is applied.
1077+
normalize_pred (callable, optional): Function to use to normalize the predicted strings.
1078+
Defaults to None if no normalization is applied.
1079+
strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
1080+
sample_scoring_function (callable or str, optional): Function to use to score each sample.
1081+
Either pass the full function (should take a string prediction and a string gold, and return a score between 0 and 1)
1082+
a string (any of `prefix`, `suffix` or `full`) to define the type of exact match that you want, or nothing, which defaults to "full".
1083+
`prefix` checks if the prediction starts with the gold,
1084+
`suffix` if the prediction ends with the gold,
1085+
`full` if the prediction and gold are equal
1086+
"""
1087+
self.k = k
1088+
self.n = n
1089+
self.normalize_gold = normalize_gold
1090+
self.normalize_pred = normalize_pred
1091+
self.strip_strings = strip_strings
1092+
1093+
# Manage the logic of per-prediction sample scoring
1094+
if callable(sample_scoring_function):
1095+
self.score_sample = sample_scoring_function
1096+
self.type_exact_match = None
1097+
else:
1098+
if isinstance(sample_scoring_function, str):
1099+
if sample_scoring_function not in ["prefix", "suffix", "full"]:
1100+
raise ValueError(
1101+
f"sample_scoring_function (used in pass@k) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
1102+
)
1103+
self.type_exact_match = sample_scoring_function
1104+
else:
1105+
self.type_exact_match = "full"
1106+
self.score_sample = self.default_sample_scoring
1107+
1108+
def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
1109+
"""Computes the metric over a list of golds and predictions for one single item with possibly many samples.
1110+
It applies normalisation (if needed) to model prediction and gold, computes their per prediction score,
1111+
then aggregates the scores over the samples using a pass@k.
1112+
1113+
Args:
1114+
golds (list[str]): Reference targets
1115+
predictions (list[str]): k predicted strings
1116+
1117+
Returns:
1118+
float: Aggregated score over the current sample's items.
1119+
"""
1120+
if len(golds) > 1:
1121+
raise Exception("Cannot compute pass@k with several golds")
1122+
1123+
if self.n is None:
1124+
self.n = len(predictions)
1125+
logger.warning("n undefined in the pass@k. We assume it's the same as the sample's number of predictions.")
1126+
elif len(predictions) < self.n:
1127+
logger.warning(f"Number of predictions is less than {self.n} for pass@k.")
1128+
1129+
gold = self.get_processed_gold(golds[0])
1130+
1131+
all_scores = []
1132+
for pred in predictions[: self.n]:
1133+
cur_pred = self.get_processed_pred(pred=pred)
1134+
all_scores.append(self.score_sample(cur_pred, gold))
1135+
1136+
return self.pass_at_k(all_scores)
1137+
1138+
def get_processed_gold(self, gold: str) -> str:
1139+
if self.strip_strings:
1140+
gold = gold.strip()
1141+
1142+
if self.normalize_gold:
1143+
gold = self.normalize_gold(gold)
1144+
1145+
return gold
1146+
1147+
def get_processed_pred(self, pred: str) -> str:
1148+
if not pred:
1149+
return ""
1150+
1151+
if self.strip_strings:
1152+
pred = pred.strip()
1153+
1154+
if self.normalize_pred:
1155+
pred = self.normalize_pred(pred)
1156+
1157+
return pred
1158+
1159+
def default_sample_scoring(self, pred: str, gold: str) -> int:
1160+
if self.type_exact_match == "prefix":
1161+
return 1 if pred.startswith(gold) else 0
1162+
if self.type_exact_match == "suffix":
1163+
return 1 if pred.endswith(gold) else 0
1164+
return 1 if gold == pred else 0
1165+
1166+
def pass_at_k(self, all_scores: list[int]) -> float:
1167+
"""Algo from https://arxiv.org/pdf/2107.03374"""
1168+
c: int = all_scores.count(1)
1169+
if self.n - c < self.k:
1170+
return 1.0
1171+
1172+
return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))

src/lighteval/metrics/utils/extractive_match_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from typing import Any, Literal, Sequence
2828

2929
import sympy
30-
from sympy import FiniteSet, Number
30+
from sympy import Basic, FiniteSet, MatrixBase, Number
3131
from sympy.parsing import parse_expr
3232

3333
from lighteval.metrics.utils.math_comparison import should_treat_as_complex
@@ -487,7 +487,9 @@ def extract_latex(
487487
return latex_exprs[0], latex_strs[0]
488488

489489

490-
def extract_match(match: re.Match, target_type: ExtractionTarget, timeout_seconds: int):
490+
def extract_match(
491+
match: re.Match, target_type: ExtractionTarget, timeout_seconds: int
492+
) -> tuple[Basic | MatrixBase | str | None, str]:
491493
"""Extracts the match from the regex match.
492494
493495
Args:

src/lighteval/tasks/default_prompts.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,23 @@ def gpqa(line, task_name: str = None):
729729
)
730730

731731

732+
def gpqa_instruct(line, task_name: str = None):
733+
"""Prompt template adapted from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14"""
734+
gold_index = random.randint(0, 3)
735+
choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]]
736+
choices.insert(gold_index, line["Correct Answer"])
737+
query_template = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}"
738+
query = query_template.format(A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"])
739+
740+
return Doc(
741+
task_name=task_name,
742+
query=query,
743+
choices=LETTER_INDICES[: len(choices)],
744+
gold_index=gold_index,
745+
instruction=query,
746+
)
747+
748+
732749
def gsm8k(line, task_name: str = None):
733750
# Has special analysis in metric for number decomposition
734751
return Doc(

src/lighteval/tasks/default_tasks.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7720,6 +7720,54 @@
77207720
trust_dataset=True,
77217721
version=0,
77227722
)
7723+
gpqa_diamond_instruct_lighteval = LightevalTaskConfig(
7724+
name="gpqa:diamond",
7725+
suite=["lighteval"],
7726+
prompt_function=prompt.gpqa_instruct,
7727+
hf_repo="Idavidrein/gpqa",
7728+
hf_subset="gpqa_diamond",
7729+
hf_avail_splits=["train"],
7730+
evaluation_splits=["train"],
7731+
few_shots_split=None,
7732+
few_shots_select=None,
7733+
generation_size=32768, # needed for reasoning models like R1
7734+
metric=[Metrics.gpqa_instruct_metric],
7735+
stop_sequence=[], # no stop sequence, will use eos token
7736+
trust_dataset=True,
7737+
version=0,
7738+
)
7739+
gpqa_extended_instruct_lighteval = LightevalTaskConfig(
7740+
name="gpqa:extended",
7741+
suite=["lighteval"],
7742+
prompt_function=prompt.gpqa_instruct,
7743+
hf_repo="Idavidrein/gpqa",
7744+
hf_subset="gpqa_extended",
7745+
hf_avail_splits=["train"],
7746+
evaluation_splits=["train"],
7747+
few_shots_split=None,
7748+
few_shots_select=None,
7749+
generation_size=32768, # needed for reasoning models like R1
7750+
metric=[Metrics.gpqa_instruct_metric],
7751+
stop_sequence=[], # no stop sequence, will use eos token
7752+
trust_dataset=True,
7753+
version=0,
7754+
)
7755+
gpqa_main_instruct_lighteval = LightevalTaskConfig(
7756+
name="gpqa:main",
7757+
suite=["lighteval"],
7758+
prompt_function=prompt.gpqa_instruct,
7759+
hf_repo="Idavidrein/gpqa",
7760+
hf_subset="gpqa_main",
7761+
hf_avail_splits=["train"],
7762+
evaluation_splits=["train"],
7763+
few_shots_split=None,
7764+
few_shots_select=None,
7765+
generation_size=32768, # needed for reasoning models like R1
7766+
metric=[Metrics.gpqa_instruct_metric],
7767+
stop_sequence=[], # no stop sequence, will use eos token
7768+
trust_dataset=True,
7769+
version=0,
7770+
)
77237771
gre_reading_comprehension_bigbench = LightevalTaskConfig(
77247772
name="gre_reading_comprehension",
77257773
suite=["bigbench", "bigbench_json"],

0 commit comments

Comments
 (0)