
Commit 441d7a4

Pass@k (#519)
* init
* correct typing
* added defaults
* small fix
1 parent 15bdbb8 commit 441d7a4

File tree

2 files changed: +141 −1 lines changed


src/lighteval/metrics/metrics.py

Lines changed: 25 additions & 0 deletions
@@ -48,6 +48,7 @@
     Faithfulness,
     LoglikelihoodAcc,
     MajAtK,
+    PassAtK,
     Recall,
     StringDistance,
     acc_golds_likelihood,
@@ -369,6 +370,30 @@ class Metrics(Enum):
         corpus_level_fn=CorpusLevelF1Score(average=None, num_classes=3).compute,
         higher_is_better=True,
     )
+    pass_at_1 = SampleLevelMetric(
+        metric_name="pass@1:32_samples",
+        sample_level_fn=PassAtK(k=1, n=32, strip_strings=True).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
+    pass_at_10 = SampleLevelMetric(
+        metric_name="pass@10:32_samples",
+        sample_level_fn=PassAtK(k=10, n=32, strip_strings=True).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
+    pass_at_100 = SampleLevelMetric(
+        metric_name="pass@100:32_samples",
+        sample_level_fn=PassAtK(k=100, n=32, strip_strings=True).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
     perfect_exact_match = SampleLevelMetric(
         metric_name="perfect_em",
         sample_level_fn=ExactMatches().compute,
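Each of these three entries samples n = 32 generations per prompt and estimates pass@k for k in {1, 10, 100}. One caveat worth noting: the estimator (see pass_at_k in the next file) is only meaningful for n >= k, and since it short-circuits to 1.0 whenever n - c < k, the pass@100 entry always returns 1.0 with n fixed at 32. As a minimal sketch, not part of the commit, here is how the new metric's sample-level function could be exercised directly, assuming lighteval with this patch is installed (the gold and prediction strings are made up):

    # Illustrative only: call the sample-level function behind pass_at_1.
    from lighteval.metrics.metrics import Metrics

    fn = Metrics.pass_at_1.value.sample_level_fn  # PassAtK(k=1, n=32, strip_strings=True).compute
    # 8 correct samples out of n = 32 -> pass@1 = 1 - C(24, 1) / C(32, 1) = 0.25
    score = fn(golds=["7"], predictions=["7"] * 8 + ["3"] * 24)
    print(score)  # 0.25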

src/lighteval/metrics/metrics_sample.py

Lines changed: 116 additions & 1 deletion
@@ -26,7 +26,7 @@
 
 import logging
 import os
-from typing import Callable, Literal
+from typing import Callable, Literal, Union
 
 import nltk
 import numpy as np
@@ -1055,3 +1055,118 @@ def compute_score(self, pred: str, gold: str) -> int:
         if self.type_exact_match == "suffix":
             return 1 if pred.endswith(gold) else 0
         return 1 if gold == pred else 0
+
+
+class PassAtK:
+    def __init__(
+        self,
+        k: int,
+        n: int = None,
+        normalize_gold: Callable = None,
+        normalize_pred: Callable = None,
+        strip_strings: bool = False,
+        sample_scoring_function: Union[Callable[[str, str], float], str] = None,
+    ):
+        """Computes pass at k.
+
+        Args:
+            k (int): Threshold for the number of successful attempts.
+            n (int): Number of samples to generate.
+            normalize_gold (callable, optional): Function to use to normalize the reference strings.
+                Defaults to None if no normalization is applied.
+            normalize_pred (callable, optional): Function to use to normalize the predicted strings.
+                Defaults to None if no normalization is applied.
+            strip_strings (bool, optional): Whether to strip both reference and predictions. Defaults to False.
+            sample_scoring_function (callable or str, optional): Function to use to score each sample.
+                Either pass the full function (which should take a string prediction and a string gold and return a score between 0 and 1),
+                a string (any of `prefix`, `suffix` or `full`) to define the type of exact match you want, or nothing to default to "full".
+                `prefix` checks if the prediction starts with the gold,
+                `suffix` if the prediction ends with the gold,
+                `full` if the prediction and gold are equal.
+        """
+        self.k = k
+        self.n = n
+        self.normalize_gold = normalize_gold
+        self.normalize_pred = normalize_pred
+        self.strip_strings = strip_strings
+
+        # Manage the logic of per-prediction sample scoring
+        if callable(sample_scoring_function):
+            self.score_sample = sample_scoring_function
+            self.type_exact_match = None
+        else:
+            if isinstance(sample_scoring_function, str):
+                if sample_scoring_function not in ["prefix", "suffix", "full"]:
+                    raise ValueError(
+                        f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
+                    )
+                self.type_exact_match = sample_scoring_function
+            else:
+                self.type_exact_match = "full"
+            self.score_sample = self.default_sample_scoring
+
+    def compute(self, golds: list[str], predictions: list[str], **kwargs) -> float:
+        """Computes the metric over a list of golds and predictions for one single item with possibly many samples.
+        It applies normalization (if needed) to the model prediction and gold, computes their per-prediction score,
+        then aggregates the scores over the samples using pass@k.
+
+        Args:
+            golds (list[str]): Reference targets
+            predictions (list[str]): k predicted strings
+
+        Returns:
+            float: Aggregated score over the current sample's items.
+        """
+        if len(golds) > 1:
+            raise Exception("Cannot compute pass@k with several golds")
+
+        if self.n is None:
+            self.n = len(predictions)
+            logger.warning("n undefined in the pass@k. We assume it's the same as the sample's number of predictions.")
+        elif len(predictions) < self.n:
+            logger.warning(f"Number of predictions is less than {self.n} for pass@k.")
+
+        gold = self.get_processed_gold(golds[0])
+
+        all_scores = []
+        for pred in predictions[: self.n]:
+            cur_pred = self.get_processed_pred(pred=pred)
+            all_scores.append(self.score_sample(cur_pred, gold))
+
+        return self.pass_at_k(all_scores)
+
+    def get_processed_gold(self, gold: str) -> str:
+        if self.strip_strings:
+            gold = gold.strip()
+
+        if self.normalize_gold:
+            gold = self.normalize_gold(gold)
+
+        return gold
+
+    def get_processed_pred(self, pred: str) -> str:
+        if not pred:
+            return ""
+
+        if self.strip_strings:
+            pred = pred.strip()
+
+        if self.normalize_pred:
+            pred = self.normalize_pred(pred)
+
+        return pred
+
+    def default_sample_scoring(self, pred: str, gold: str) -> int:
+        if self.type_exact_match == "prefix":
+            return 1 if pred.startswith(gold) else 0
+        if self.type_exact_match == "suffix":
+            return 1 if pred.endswith(gold) else 0
+        return 1 if gold == pred else 0
+
+    def pass_at_k(self, all_scores: list[int]) -> float:
+        """Unbiased pass@k estimator from https://arxiv.org/pdf/2107.03374"""
+        c: int = all_scores.count(1)
+        if self.n - c < self.k:
+            return 1.0
+
+        return 1.0 - np.prod(1.0 - self.k / np.arange(self.n - c + 1, self.n + 1))
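For reference, the product in pass_at_k is a numerically stable rewrite of the closed-form unbiased estimator pass@k = 1 - C(n-c, k) / C(n, k) from the Codex paper linked in the docstring. A quick standalone check of the equivalence (the values of n, k, and c here are arbitrary; only math and numpy are used):

    # Sanity check: product form used in PassAtK.pass_at_k vs. the closed form.
    from math import comb
    import numpy as np

    n, k, c = 32, 10, 4  # 32 samples, 4 correct, estimating pass@10
    closed_form = 1.0 - comb(n - c, k) / comb(n, k)
    product_form = 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
    assert abs(closed_form - product_form) < 1e-9
    print(round(closed_form, 4))  # 0.7966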
