
Commit 46a3f0d

mmlu benchmark (#164)

Added the MMLU benchmark. The HFDataReader only downloads 32/57 categories; that needs to be fixed.

Authored by jyotianeja
Co-authored-by: jyotianeja <[email protected]>

1 parent f948f57 · commit 46a3f0d

File tree

3 files changed: +221 -0 lines changed

eureka_ml_insights/data_utils/mmlu_utils.py

Lines changed: 107 additions & 0 deletions
```python
from dataclasses import dataclass

import pandas as pd

from .transform import DFTransformBase

# The list of 57 tasks is taken from `https://github.com/hendrycks/test/blob/master/categories.py`

MMLUCategories = {
    "STEM": [
        "astronomy",
        "college_physics",
        "conceptual_physics",
        "high_school_physics",
        "college_chemistry",
        "high_school_chemistry",
        "college_biology",
        "high_school_biology",
        "college_computer_science",
        "computer_security",
        "high_school_computer_science",
        "machine_learning",
        "abstract_algebra",
        "college_mathematics",
        "elementary_mathematics",
        "high_school_mathematics",
        "high_school_statistics",
        "electrical_engineering",
    ],
    "Humanities": [
        "high_school_european_history",
        "high_school_us_history",
        "high_school_world_history",
        "prehistory",
        "formal_logic",
        "logical_fallacies",
        "moral_disputes",
        "moral_scenarios",
        "philosophy",
        "world_religions",
        "international_law",
        "jurisprudence",
        "professional_law",
    ],
    "Social Sciences": [
        "high_school_government_and_politics",
        "public_relations",
        "security_studies",
        "us_foreign_policy",
        "human_sexuality",
        "sociology",
        "high_school_macroeconomics",
        "high_school_microeconomics",
        "econometrics",
        "high_school_geography",
        "high_school_psychology",
        "professional_psychology",
    ],
    "Other (Business, Health, Misc.)": [
        "global_facts",
        "miscellaneous",
        "professional_accounting",
        "business_ethics",
        "management",
        "marketing",
        "anatomy",
        "clinical_knowledge",
        "college_medicine",
        "human_aging",
        "medical_genetics",
        "nutrition",
        "professional_medicine",
        "virology",
    ],
}

MMLUTaskToCategories = {task: cat for cat, tasks in MMLUCategories.items() for task in tasks}

MMLUAll = [task for cat in MMLUCategories.values() for task in cat]


@dataclass
class CreateMMLUPrompts(DFTransformBase):
    """Transform to create prompts for MMLU dataset."""

    def __init__(self):
        self.multi_option_example_format = "{}\n{}\nAnswer with the option's letter from the given options directly."

    def _create_prompt(self, sample):
        question = sample["question"]
        options = sample["choices"]
        example = ""
        start_chr = "A"
        index2ans = {}

        for option in options:
            example += f"({start_chr}) {option}\n"
            index2ans[start_chr] = option
            start_chr = chr(ord(start_chr) + 1)

        prompt = self.multi_option_example_format.format(question, example)

        return prompt

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        df["prompt"] = df.apply(self._create_prompt, axis=1)

        return df
```
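
To illustrate the prompt format this transform produces, here is a minimal usage sketch on a toy DataFrame; the question and choices are invented for illustration, but the column names match the MMLU schema the transform expects:

```python
import pandas as pd

from eureka_ml_insights.data_utils.mmlu_utils import CreateMMLUPrompts

# Toy rows with the same schema the transform expects: "question" and "choices".
df = pd.DataFrame(
    {
        "question": ["What is 2 + 2?"],
        "choices": [["3", "4", "5", "6"]],
    }
)

df = CreateMMLUPrompts().transform(df)
print(df.loc[0, "prompt"])
# What is 2 + 2?
# (A) 3
# (B) 4
# (C) 5
# (D) 6
#
# Answer with the option's letter from the given options directly.
```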

eureka_ml_insights/user_configs/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -62,6 +62,7 @@
 from .mathverse import MATHVERSE_PIPELINE
 from .mathvision import MATHVISION_PIPELINE
 from .mathvista import MATHVISTA_PIPELINE
+from .mmlu import MMLU_BASELINE_PIPELINE
 from .mmmu import MMMU_BASELINE_PIPELINE
 from .nocaps import NOCAPS_PIPELINE
 from .nondeterminism import (
@@ -132,6 +133,7 @@
     GPQA_PIPELINE_5Run,
     Drop_Experiment_Pipeline,
     GEOMETER_PIPELINE,
+    MMLU_BASELINE_PIPELINE,
     MMMU_BASELINE_PIPELINE,
     KITAB_ONE_BOOK_CONSTRAINT_PIPELINE,
     KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_WITH_CONTEXT,
```
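
With these two lines, the new pipeline class is re-exported at the `user_configs` package level, which is how pipelines in this repo are looked up by name; the model config itself is supplied at run time (see the class docstring below). A minimal sketch of consuming the export, assuming the package is installed:

```python
# Minimal sketch: the new pipeline class is now importable from the package root.
from eureka_ml_insights.user_configs import MMLU_BASELINE_PIPELINE
```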
eureka_ml_insights/user_configs/mmlu.py

Lines changed: 112 additions & 0 deletions
```python
import os
from typing import Any

from eureka_ml_insights.configs.experiment_config import ExperimentConfig
from eureka_ml_insights.core import EvalReporting, Inference, PromptProcessing
from eureka_ml_insights.data_utils import (
    ASTEvalTransform,
    ColumnRename,
    CopyColumn,
    DataReader,
    HFDataReader,
    MapStringsTransform,
    SequenceTransform,
    AddColumnAndData,
    SamplerTransform,
)
from eureka_ml_insights.data_utils.mmlu_utils import (
    CreateMMLUPrompts,
    MMLUAll,
    MMLUTaskToCategories,
)
from eureka_ml_insights.metrics import CountAggregator, MMMUMetric

from eureka_ml_insights.data_utils.data import DataLoader

from eureka_ml_insights.configs import (
    AggregatorConfig,
    DataSetConfig,
    EvalReportingConfig,
    InferenceConfig,
    MetricConfig,
    ModelConfig,
    PipelineConfig,
    PromptProcessingConfig,
)


class MMLU_BASELINE_PIPELINE(ExperimentConfig):
    """
    This defines an ExperimentConfig pipeline for the MMLU dataset.
    There is no model_config by default; the model config must be passed in via command line.
    """

    def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None, **kwargs: dict[str, Any]) -> PipelineConfig:

        self.data_processing_comp = PromptProcessingConfig(
            component_type=PromptProcessing,
            data_reader_config=DataSetConfig(
                HFDataReader,
                {
                    "path": "cais/mmlu",
                    "split": "test",
                    "tasks": ["abstract_algebra"],  # MMLUAll
                    "transform": SequenceTransform(
                        [
                            # ASTEvalTransform(columns=["choices"]),
                            CreateMMLUPrompts(),
                            ColumnRename(name_mapping={"answer": "ground_truth", "choices": "target_options"}),
                            AddColumnAndData("question_type", "multiple-choice"),
                            # SamplerTransform(sample_count=10, random_seed=42),
                        ]
                    ),
                },
            ),
            output_dir=os.path.join(self.log_dir, "data_processing_output"),
            ignore_failure=False,
        )

        # Configure the inference component
        self.inference_comp = InferenceConfig(
            component_type=Inference,
            model_config=model_config,
            data_loader_config=DataSetConfig(
                DataLoader,
                {"path": os.path.join(self.data_processing_comp.output_dir, "transformed_data.jsonl")},
            ),
            output_dir=os.path.join(self.log_dir, "inference_result"),
            resume_from=resume_from,
        )

        # Configure the evaluation and reporting component.
        self.evalreporting_comp = EvalReportingConfig(
            component_type=EvalReporting,
            data_reader_config=DataSetConfig(
                DataReader,
                {
                    "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"),
                    "format": ".jsonl",
                    "transform": SequenceTransform(
                        [
                            CopyColumn(column_name_src="__hf_task", column_name_dst="category"),
                            MapStringsTransform(
                                columns=["category"],
                                mapping=MMLUTaskToCategories,
                            ),
                        ]
                    ),
                },
            ),
            metric_config=MetricConfig(MMMUMetric),
            aggregator_configs=[
                AggregatorConfig(CountAggregator, {"column_names": ["MMMUMetric_result"], "normalize": True}),
                AggregatorConfig(
                    CountAggregator,
                    {"column_names": ["MMMUMetric_result"], "group_by": "category", "normalize": True},
                ),
            ],
            output_dir=os.path.join(self.log_dir, "eval_report"),
        )

        # Configure the pipeline
        return PipelineConfig([self.data_processing_comp, self.inference_comp, self.evalreporting_comp], self.log_dir)
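```

Since the eval component rolls per-task results up to category level by copying `__hf_task` into `category` and mapping it through `MMLUTaskToCategories`, a quick sanity-check sketch of that mapping (assuming `eureka_ml_insights` is importable) might look like:

```python
from eureka_ml_insights.data_utils.mmlu_utils import MMLUAll, MMLUTaskToCategories

# The flattened task list should cover all 57 MMLU subjects.
assert len(MMLUAll) == 57

# Each task maps to exactly one of the four top-level categories.
print(MMLUTaskToCategories["abstract_algebra"])  # STEM
print(MMLUTaskToCategories["econometrics"])      # Social Sciences
print(MMLUTaskToCategories["virology"])          # Other (Business, Health, Misc.)
```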
