
Commit 161d47c

Fixing naming for sample evals + adding reqs in aime24 (#989)
* homogenize k and n in parametrizable metrics
* updated aime, last metric fixes
* fix
* restore rm import
* restore
* update doc
* gpqa fix
* pass at
* recall
* test
1 parent e7d885c · commit 161d47c

File tree: 6 files changed, +88 -90 lines


docs/source/package_reference/metrics.mdx

Lines changed: 12 additions & 6 deletions
@@ -56,15 +56,21 @@
 [[autodoc]] metrics.metrics_sample.BLEU
 ### StringDistance
 [[autodoc]] metrics.metrics_sample.StringDistance
+
+### Metrics allowing sampling
+#### PassAtK
+[[autodoc]] metrics.metrics_sample.PassAtK
+#### MajAtN
+[[autodoc]] metrics.metrics_sample.MajAtN
+#### AvgAtN
+[[autodoc]] metrics.metrics_sample.AvgAtN
+
+## LLM-as-a-Judge
+### JudgeLM
+[[autodoc]] metrics.utils.llm_as_judge.JudgeLM
 ### JudgeLLM
 [[autodoc]] metrics.metrics_sample.JudgeLLM
 ### JudgeLLMMTBench
 [[autodoc]] metrics.metrics_sample.JudgeLLMMTBench
 ### JudgeLLMMixEval
 [[autodoc]] metrics.metrics_sample.JudgeLLMMixEval
-### MajAtK
-[[autodoc]] metrics.metrics_sample.MajAtK
-
-## LLM-as-a-Judge
-### JudgeLM
-[[autodoc]] metrics.utils.llm_as_judge.JudgeLM

src/lighteval/metrics/metrics.py

Lines changed: 11 additions & 11 deletions
@@ -41,7 +41,7 @@
     MRR,
     ROUGE,
     AccGoldLikelihood,
-    AvgAtK,
+    AvgAtN,
     BertScore,
     ExactMatches,
     Extractiveness,
@@ -50,7 +50,7 @@
     GPassAtK,
     JudgeLLMSimpleQA,
     LoglikelihoodAcc,
-    MajAtK,
+    MajAtN,
     PassAtK,
     Recall,
     StringDistance,
@@ -85,16 +85,16 @@ class Metrics(Enum):
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
-    avg_at_k = SampleLevelMetric(
-        metric_name="avg@k",
-        sample_level_fn=AvgAtK(strip_strings=True),
+    avg_at_n = SampleLevelMetric(
+        metric_name="avg@n",
+        sample_level_fn=AvgAtN(strip_strings=True),
         category=SamplingMethod.GENERATIVE,
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
-    avg_at_k_math = SampleLevelMetric(
-        metric_name="avg@k",
-        sample_level_fn=AvgAtK(
+    avg_at_n_math = SampleLevelMetric(
+        metric_name="avg@n",
+        sample_level_fn=AvgAtN(
             sample_scoring_function=MultilingualExtractiveMatchMetric(
                 language=Language.ENGLISH,
                 gold_extraction_target=[ExprExtractionConfig(), LatexExtractionConfig()],
@@ -365,9 +365,9 @@ class Metrics(Enum):
         corpus_level_fn=CorpusLevelF1Score(None),
         higher_is_better=True,
     )
-    maj_at_k = SampleLevelMetric(
-        metric_name="maj@k",
-        sample_level_fn=MajAtK(),
+    maj_at_n = SampleLevelMetric(
+        metric_name="maj@n",
+        sample_level_fn=MajAtN(),
         category=SamplingMethod.GENERATIVE,
         corpus_level_fn=np.mean,
         higher_is_better=True,
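The enum entries above leave n unset (AvgAtN(strip_strings=True), MajAtN()), so n is supplied later as a sample parameter. A minimal sketch of how the resulting metric label is assembled, following the sample_params naming logic shown in the metric_utils.py hunk at the end of this commit; build_metric_label and the example call are illustrative, not lighteval API:

```python
def build_metric_label(metric_name: str, sample_params: dict) -> str:
    # Mirrors metric_utils.py: key=value pairs joined with "&", appended after a colon.
    sample_params_name = "&".join(f"{k}={v}" for k, v in sample_params.items())
    return f"{metric_name}:{sample_params_name}"

# The renamed metrics, parametrized with a sampling budget of 8:
print(build_metric_label("avg@n", {"n": 8}))  # avg@n:n=8
print(build_metric_label("maj@n", {"n": 8}))  # maj@n:n=8
```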

src/lighteval/metrics/metrics_sample.py

Lines changed: 33 additions & 41 deletions
@@ -63,7 +63,7 @@
 
 class SampleLevelComputation(ABC):
     @abstractmethod
-    def compute(self, model_response: ModelResponse, doc: Doc, **kwargs):
+    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
         raise NotImplementedError
 
     def __str__(self):
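The abstract compute signature now takes the document first and the model response second. A minimal sketch of a custom sample-level computation written against that order; LengthMatch is a made-up example, not part of lighteval, and it assumes Doc.get_golds() and ModelResponse.final_text behave as used elsewhere in this diff:

```python
class LengthMatch(SampleLevelComputation):
    """Toy example: score 1.0 when the first prediction has the same length as the first gold."""

    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
        gold = doc.get_golds()[0]
        pred = model_response.final_text[0]
        return float(len(pred) == len(gold))
```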
@@ -444,7 +444,7 @@ def __init__(self, length_normalization: bool = False):
         """
         self.length_normalization = length_normalization
 
-    def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float:
+    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
         """Mean reciprocal rank. Measures the quality of a ranking of choices (ordered by correctness).
 
         Args:
@@ -1129,14 +1129,13 @@ def __init__(
             raise ValueError(f"Unknown normalization function: {normalize}")
         else:
             self.normalize = normalize
-
         self.strip_strings = strip_strings
 
         if callable(sample_scoring_function):
             self.compute_score = sample_scoring_function
             self.type_exact_match = None
         elif isinstance(sample_scoring_function, SampleLevelComputation):
-            self.score_sample = sample_scoring_function.compute
+            self.compute_score = sample_scoring_function.compute
             self.type_exact_match = None
         else:
             if isinstance(sample_scoring_function, str):
@@ -1145,11 +1144,9 @@
                         f"type_exact_match (used in parametrized_exact_match) must be one of prefix, suffix, or full. Was {sample_scoring_function} instead."
                     )
                 self.type_exact_match = sample_scoring_function
-                self.score_sample = self.default_sample_scoring
             else:
                 self.type_exact_match = "full"
                 self.compute_score = self.default_sample_scoring
-                self.score_sample = self.default_sample_scoring
 
     def preprocess(self, text: str) -> str:
         if not text:
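Per the constructor branches above, sample_scoring_function can be a plain callable, a SampleLevelComputation instance, or one of the strings "prefix"/"suffix"/"full"; after this commit all three paths feed self.compute_score. A hedged sketch of the three call styles, assuming these keyword arguments are forwarded unchanged to this base class by the sampling metrics further down (AvgAtN(**kwargs), MajAtN(**kwargs)):

```python
# 1) Arbitrary callable taking (doc, model_response) and returning a float.
avg_callable = AvgAtN(
    sample_scoring_function=lambda doc, model_response: float(
        model_response.final_text[0] == doc.get_golds()[0]
    )
)

# 2) A SampleLevelComputation instance; its .compute method becomes compute_score.
class FirstWordMatch(SampleLevelComputation):
    def compute(self, doc, model_response, **kwargs) -> float:
        return float(model_response.final_text[0].split()[:1] == doc.get_golds()[0].split()[:1])

avg_scored = AvgAtN(sample_scoring_function=FirstWordMatch())

# 3) A string selects the built-in exact-match flavour ("prefix", "suffix" or "full").
avg_prefix = AvgAtN(sample_scoring_function="prefix", strip_strings=True)
```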
@@ -1176,19 +1173,19 @@ def name_metrics(self) -> str | list[str]:
         raise NotImplementedError
 
 
-class AvgAtK(SamplingMetric, SampleLevelComputation):
-    def __init__(self, k: int | None = None, **kwargs):
-        """Sample score averages all the individual k predictions scores.
+class AvgAtN(SamplingMetric, SampleLevelComputation):
+    def __init__(self, n: int | None = None, **kwargs):
+        """Sample score averages all the individual n predictions scores.
 
         Args:
-            k (int | None): The number of top choices to consider.
+            n (int | None): Number of samples to generate
             **kwargs: Additional keyword arguments.
         """
         super().__init__(**kwargs)
-        self.k = k
-        self.attribute_must_be_set = ["k"]
+        self.n = n
+        self.attribute_must_be_set = ["n"]
 
-    def compute(self, model_response: ModelResponse, doc: Doc, **kwargs):
+    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
         """Computes the metric over a list of golds and predictions for one single sample.
         It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones,
         then compares it to the gold.
@@ -1203,36 +1200,32 @@ def compute(self, model_response: ModelResponse, doc: Doc, **kwargs):
         """
         all_scores = []
         for i in range(self.k):
-            all_scores.append(self.score_sample(doc, model_response[i]))
+            all_scores.append(self.compute_score(doc, model_response[i]))
 
         avg_score = np.mean(all_scores)
         return avg_score
 
     def num_samples(self):
-        """Get the number of samples for this metric.
-
-        Returns:
-            int: The number of samples
-        """
-        return self.k
+        return self.n
 
 
-class MajAtK(SamplingMetric, SampleLevelComputation):
-    def __init__(self, k: int | None = None, **kwargs):
+class MajAtN(SamplingMetric, SampleLevelComputation):
+    def __init__(self, n: int | None = None, **kwargs):
         """An exact match class.
 
         Args:
-            k (int): The number of top choices to consider.
+            n (int): Total number of samples to generate
             **kwargs: Additional keyword arguments.
         """
         super().__init__(**kwargs)
 
-        self.k = k
-        self.attribute_must_be_set = ["k"]
+        self.n = n
+        self.attribute_must_be_set = ["n"]
 
     def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
         """Computes the metric over a list of golds and predictions for one single sample.
-        It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones, then compares it to the gold.
+        It applies normalisation (if needed) to model prediction and gold, and takes the most frequent answer of all the available ones,
+        then compares it to the gold.
 
         Args:
             doc (Doc): The document containing gold references.
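To make the two aggregations concrete, here is a standalone toy walk-through of what avg@n and maj@n compute for one document, mirroring the AvgAtN averaging above and the MajAtN majority vote completed in the next hunk. Plain Python, no lighteval imports; the per-sample scorer is a simple exact match chosen for illustration:

```python
gold = "42"
samples = ["42", "41", "42", "7"]  # n = 4 generations for the same prompt

# avg@n: mean of the per-sample scores.
per_sample = [float(pred == gold) for pred in samples]
avg_at_n = sum(per_sample) / len(per_sample)   # 0.5

# maj@n: score only the most frequent answer against the gold.
majority = max(samples, key=samples.count)     # "42"
maj_at_n = float(majority == gold)             # 1.0
```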
@@ -1243,39 +1236,38 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
             float: Aggregated score over the current sample's items.
         """
         if self.k is None:
-            raise Exception("You did not set the value of k")
+            raise Exception("You did not set the value of n")
 
         golds = doc.get_golds()
-
         if len(golds) > 1:
-            raise Exception("Cannot compute maj@k with several golds")
+            raise Exception("Cannot compute maj@n with several golds")
 
         processed_choices = [self.preprocess(text=g) for g in doc.get_golds()]
         new_doc = Doc(
             choices=processed_choices,
             query=doc.query,
-            gold_index=list(range(len(processed_choices))),
+            gold_index=doc.gold_index,
         )
         all_answers = []
-        for pred in model_response.final_text[: self.k]:
+        for pred in model_response.final_text[: self.n]:
             all_answers.append(self.preprocess(text=pred))
         majority_prediction = max(all_answers, key=all_answers.count)
         new_model_response = ModelResponse(
             text=[majority_prediction],
         )
-        return self.compute_score(new_doc, new_model_response)
+        return self.compute_score(new_model_response, new_doc)
 
     def num_samples(self):
-        return self.k
+        return self.n
 
 
 class PassAtK(SamplingMetric, SampleLevelComputation):
     def __init__(self, k: int | None = None, n: int | None = None, **kwargs):
-        """Computing pass at k
+        """Computing pass at k with an estimator
 
         Args:
-            k (int | None): Threshold for the number of successful attempts.
-            n (int | None): Number of samples to generate.
+            k (int | None): Number of correct samples threshold
+            n (int | None): Total number of samples to generate.
             **kwargs: Additional keyword arguments.
         """
         super().__init__(**kwargs)
@@ -1320,7 +1312,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
             new_model_response = ModelResponse(
                 text=[cur_pred],
             )
-            all_scores.append(self.score_sample(doc=new_doc, model_response=new_model_response))
+            all_scores.append(self.compute_score(doc=new_doc, model_response=new_model_response))
 
         return self.pass_at_k(all_scores)
 
@@ -1348,8 +1340,8 @@ def __init__(
         """Computing G-Pass@k from http://arxiv.org/abs/2412.13147
 
         Args:
-            k (Union[int, list[int]] | None): The number of successful attempts to be considered.
-            n (int | None): Number of samples to generate.
+            k (Union[int, list[int]] | None): Number of correct samples threshold
+            n (int | None): Total number of samples to generate.
             thresholds (list[float]): Thresholds to control successful attempts in k generate.
             name_prefix (str | None): Prefix for the metric name.
             **kwargs: Additional keyword arguments.
@@ -1370,7 +1362,7 @@ def k(self):
     def k(self, new_val):
         self._k = as_list(new_val)
 
-    def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float:
+    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
         """Computes the metric over a list of golds and predictions for one single item with possibly many samples.
         It applies normalisation (if needed) to model prediction and gold, computes their per prediction score,
         then aggregates the scores over the samples using a pass@k.
@@ -1410,7 +1402,7 @@ def compute(self, model_response: ModelResponse, doc: Doc, **kwargs) -> float:
             new_model_response = ModelResponse(
                 text=[cur_pred],
             )
-            all_scores.append(self.score_sample(new_doc, new_model_response))
+            all_scores.append(self.compute_score(new_doc, new_model_response))
 
         return self.g_pass_at_k(all_scores)

src/lighteval/metrics/utils/metric_utils.py

Lines changed: 1 addition & 1 deletion

@@ -89,7 +89,7 @@ def __call__(self, sample_params: dict | None):
         sample_params_name = "&".join(f"{k}={v}" for k, v in sample_params.items())
         if isinstance(self, MetricGrouping):
             if hasattr(self.sample_level_fn, "metric_names"):
-                # this is mostly for the gpass@k metrics
+                # this is mostly for the gpass@k metrics which redefine submetric names
                 self.metric_name = self.sample_level_fn.metric_names
             else:
                 self.metric_name = [f"{metric}:{sample_params_name}" for metric in self.metric_name]
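The PassAtK changes above only touch the docstring and the compute_score call; the pass_at_k aggregation itself is not shown in this diff. For reference, the estimator such metrics conventionally use is the unbiased pass@k from Chen et al. (2021): with n samples of which c are correct, pass@k = 1 - C(n-c, k)/C(n, k). A minimal sketch, not lighteval's implementation:

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate: probability that at least one of k samples drawn
    without replacement from n generations (c of them correct) is correct."""
    if n - c < k:
        return 1.0  # fewer than k incorrect samples: every draw of k contains a correct one
    return 1.0 - comb(n - c, k) / comb(n, k)

# e.g. 16 samples, 3 correct: pass@1 ~= 0.19, pass@4 ~= 0.61
print(pass_at_k(16, 3, 1), pass_at_k(16, 3, 4))
```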