
Commit d50bc30

Add pass@1 for GPQA-D and MATH-500 (#698)
* Add pass@1 for GPQA-D and clean up AIME
* Add pass@1 for math_500
* Add pass@1 for MATH-500
* Update test
* Fix
1 parent 96e885d commit d50bc30

File tree

5 files changed: +113 −13 lines changed

* examples/custom_tasks_tests.py
* src/lighteval/metrics/metrics.py
* src/lighteval/metrics/metrics_sample.py
* src/lighteval/models/vllm/vllm_model.py
* src/lighteval/tasks/default_tasks.py
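For reference, the pass@1:N metrics added here report the standard unbiased pass@k estimator (as in the HumanEval paper) restricted to k=1, computed over N generated samples per prompt. A minimal sketch of that estimator, assuming lighteval follows the standard formulation (this is not the repository's code):

from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    # Probability that at least one of k samples drawn without replacement
    # from n generations is correct, given c correct generations.
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# For k=1 this reduces to the fraction of correct samples:
assert pass_at_k(n=4, c=1, k=1) == 0.25
assert pass_at_k(n=1, c=1, k=1) == 1.0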

examples/custom_tasks_tests.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@
     few_shots_split=None,
     few_shots_select=None,
     generation_size=2048,
-    metric=[Metrics.gpqa_instruct_metric],
+    metric=[Metrics.gpqa_instruct_pass_at_1_1n],
     stop_sequence=[],  # no stop sequence, will use eos token
     trust_dataset=True,
     version=0,

src/lighteval/metrics/metrics.py

Lines changed: 83 additions & 0 deletions
@@ -370,6 +370,38 @@ class Metrics(Enum):
         corpus_level_fn=np.mean,
         higher_is_better=True,
     )
+    math_pass_at_1_1n = SampleLevelMetric(
+        metric_name="math_pass@1:1_samples",
+        sample_level_fn=PassAtK(
+            k=1,
+            n=1,
+            strip_strings=True,
+            # Extracting mathematical expressions and latex expressions
+            normalize_gold=lambda k: extract_target_from_pred(
+                k,
+                get_extraction_regexes(
+                    formatted_doc=None,
+                    target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
+                    language=Language.ENGLISH,
+                ),
+            ),
+            # Extracting mathematical expressions and latex expressions
+            normalize_pred=lambda k: extract_target_from_pred(
+                k,
+                get_extraction_regexes(
+                    formatted_doc=None,
+                    target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
+                    language=Language.ENGLISH,
+                ),
+            ),
+            # Uses sympy for comparision
+            sample_scoring_function=compare_gold_target,
+        ).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
     math_pass_at_1_4n = SampleLevelMetric(
         metric_name="math_pass@1:4_samples",
         sample_level_fn=PassAtK(
@@ -838,6 +870,57 @@ class Metrics(Enum):
         pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
         precision=6,
     )
+    gpqa_instruct_pass_at_1_1n = SampleLevelMetric(
+        metric_name="gpqa_pass@1:1_samples",
+        sample_level_fn=PassAtK(
+            k=1,
+            n=1,
+            sample_scoring_function=lambda pred, ref, doc: multilingual_extractive_match_metric(
+                language=Language.ENGLISH,
+                gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
+                pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
+                precision=6,
+            ).sample_level_fn([ref], [pred], doc),
+        ).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
+    gpqa_instruct_pass_at_1_4n = SampleLevelMetric(
+        metric_name="gpqa_pass@1:4_samples",
+        sample_level_fn=PassAtK(
+            k=1,
+            n=4,
+            sample_scoring_function=lambda pred, ref, doc: multilingual_extractive_match_metric(
+                language=Language.ENGLISH,
+                gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
+                pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
+                precision=6,
+            ).sample_level_fn([ref], [pred], doc),
+        ).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
+    gpqa_instruct_pass_at_1_8n = SampleLevelMetric(
+        metric_name="gpqa_pass@1:8_samples",
+        sample_level_fn=PassAtK(
+            k=1,
+            n=8,
+            sample_scoring_function=lambda pred, ref, doc: multilingual_extractive_match_metric(
+                language=Language.ENGLISH,
+                gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
+                pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
+                precision=6,
+            ).sample_level_fn([ref], [pred], doc),
+        ).compute,
+        category=MetricCategory.GENERATIVE_SAMPLING,
+        use_case=MetricUseCase.REASONING,
+        corpus_level_fn=np.mean,
+        higher_is_better=True,
+    )
 
     def __str__(self):
         return self.name.replace("_at_", "@")
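Note on the gpqa_instruct_pass_at_1_* entries: PassAtK calls its sample_scoring_function with a single (pred, ref, doc) triple, while the extractive match metric's sample_level_fn expects lists of golds and predictions plus the doc, hence the [ref], [pred] wrapping in the lambdas above. A sketch of that adapter shape (the helper name is hypothetical, not part of lighteval):

# Hypothetical helper illustrating the adapter used by the lambdas above:
# wrap a list-based metric so it can score one (pred, ref, doc) sample at a time.
def as_sample_scorer(list_based_metric):
    return lambda pred, ref, doc: list_based_metric.sample_level_fn([ref], [pred], doc)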

src/lighteval/metrics/metrics_sample.py

Lines changed: 4 additions & 2 deletions
@@ -1163,7 +1163,9 @@ def __init__(
         self.type_exact_match = "full"
         self.score_sample = self.default_sample_scoring
 
-    def compute(self, golds: list[str], predictions: list[str], **kwargs) -> dict[str, float]:
+    def compute(
+        self, golds: list[str], predictions: list[str], formatted_doc: Doc = None, **kwargs
+    ) -> dict[str, float]:
         """Computes the metric over a list of golds and predictions for one single item with possibly many samples.
         It applies normalisation (if needed) to model prediction and gold, computes their per prediction score,
         then aggregates the scores over the samples using a pass@k.
@@ -1189,7 +1191,7 @@ def compute(self, golds: list[str], predictions: list[str], **kwargs) -> dict[st
         all_scores = []
         for pred in predictions[: self.n]:
             cur_pred = self.get_processed_pred(pred=pred)
-            all_scores.append(self.score_sample(cur_pred, gold))
+            all_scores.append(self.score_sample(cur_pred, gold, formatted_doc))
 
         return self.pass_at_k(all_scores)
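The new formatted_doc parameter is threaded through to score_sample, which is what lets the GPQA lambdas in metrics.py receive the doc as a third argument. Custom scoring functions passed to PassAtK should therefore accept three positional arguments; a hedged sketch (the scorer below is illustrative, not lighteval code):

# Illustrative custom scorer: PassAtK now calls it as scorer(pred, gold, formatted_doc).
def exact_match_scorer(pred: str, gold: str, formatted_doc=None) -> float:
    # formatted_doc carries task metadata for scorers that need it; unused here.
    return float(pred.strip() == gold.strip())

# e.g. PassAtK(k=1, n=4, sample_scoring_function=exact_match_scorer).compute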

src/lighteval/models/vllm/vllm_model.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
 from typing import Optional
 
 import torch
-from pydantic import NonNegativeFloat, PositiveInt
+from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
 from tqdm import tqdm
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
@@ -82,7 +82,7 @@ class VLLMModelConfig(ModelConfig):
     gpu_memory_utilization: NonNegativeFloat = 0.9  # lower this if you are running out of memory
     max_model_length: PositiveInt | None = None  # maximum length of the model, ussually infered automatically. reduce this if you encouter OOM issues, 4096 is usually enough
     swap_space: PositiveInt = 4  # CPU swap space size (GiB) per GPU.
-    seed: PositiveInt = 1234
+    seed: NonNegativeInt = 1234
     trust_remote_code: bool = False
     use_chat_template: bool = False
     add_special_tokens: bool = True
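Relaxing seed from PositiveInt to NonNegativeInt lets seed=0 pass validation (PositiveInt requires values strictly greater than zero). A standalone pydantic sketch of the difference, not the actual VLLMModelConfig:

from pydantic import BaseModel, NonNegativeInt, PositiveInt, ValidationError

class OldConfig(BaseModel):
    seed: PositiveInt = 1234      # rejects seed=0

class NewConfig(BaseModel):
    seed: NonNegativeInt = 1234   # accepts seed=0

NewConfig(seed=0)                 # validates
try:
    OldConfig(seed=0)
except ValidationError:
    pass                          # PositiveInt requires seed > 0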

src/lighteval/tasks/default_tasks.py

Lines changed: 23 additions & 8 deletions
@@ -324,10 +324,14 @@
     few_shots_select=None,
     generation_size=32768,
     metric=[
-        Metrics.expr_gold_metric,
+        Metrics.math_pass_at_1_1n,
+        Metrics.math_pass_at_1_4n,
+        Metrics.math_pass_at_1_8n,
+        Metrics.math_pass_at_1_16n,
         Metrics.math_pass_at_1_32n,
+        Metrics.math_pass_at_1_64n,
     ],
-    version=1,
+    version=2,
 )
 aime24_gpassk = LightevalTaskConfig(
     name="aime24_gpassk",
@@ -355,10 +359,14 @@
     few_shots_select=None,
     generation_size=10000,
     metric=[
-        Metrics.expr_gold_metric,
+        Metrics.math_pass_at_1_1n,
+        Metrics.math_pass_at_1_4n,
+        Metrics.math_pass_at_1_8n,
+        Metrics.math_pass_at_1_16n,
         Metrics.math_pass_at_1_32n,
+        Metrics.math_pass_at_1_64n,
     ],
-    version=1,
+    version=2,
 )
 aime25_gpassk = LightevalTaskConfig(
     name="aime25_gpassk",
@@ -7809,10 +7817,14 @@
     few_shots_split=None,
     few_shots_select=None,
     generation_size=32768,  # needed for reasoning models like R1
-    metric=[Metrics.gpqa_instruct_metric],
+    metric=[
+        Metrics.gpqa_instruct_pass_at_1_1n,
+        Metrics.gpqa_instruct_pass_at_1_4n,
+        Metrics.gpqa_instruct_pass_at_1_8n,
+    ],
     stop_sequence=[],  # no stop sequence, will use eos token
     trust_dataset=True,
-    version=0,
+    version=1,
 )
 gpqa_extended_instruct_lighteval = LightevalTaskConfig(
     name="gpqa:extended",
@@ -9688,8 +9700,11 @@
     few_shots_split=None,
     few_shots_select=None,
     generation_size=32768,
-    metric=[Metrics.latex_gold_metric],
-    version=1,
+    metric=[
+        Metrics.math_pass_at_1_1n,
+        Metrics.math_pass_at_1_4n,
+    ],
+    version=2,
 )
 math_500_gpassk = LightevalTaskConfig(
     name="math_500_gpassk",
