
Commit 93c8caf (parent f6cc7e0)

Add Phi-specific pipelines that filter out COT

File tree: 3 files changed, 84 additions and 3 deletions

eureka_ml_insights/data_utils/arc_agi_utils.py

Lines changed: 30 additions & 0 deletions

@@ -37,3 +37,33 @@ def parse_output_answer(response):
         answer = response[start_index:end_index].strip()
 
         return answer
+
+
+@dataclass
+class ARCAGI_CleanCOTAnswer(DFTransformBase):
+    model_output_column: str
+    model_answer_column: str
+
+    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        df[self.model_answer_column] = df[self.model_output_column].apply(self.parse_output_answer)
+        return df
+
+    @staticmethod
+    def parse_output_answer(response):
+        """
+        Strip chain-of-thought text through the closing </think> tag and replace
+        None responses with an empty string.
+        Parameters:
+            response (str): Possibly None response string.
+        Returns:
+            answer (str): Text after </think>, "" for None, or the original
+                response when no tag is present.
+        """
+        if response is None:
+            return ""
+
+        # find() returns -1 when the tag is absent; check before adding the
+        # offset, otherwise the missing-tag case can never be detected.
+        tag_index = response.find("</think>")
+        if tag_index == -1:
+            return response
+
+        return response[tag_index + len("</think>"):]
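For reference, a minimal sketch of how the new transform behaves on a toy DataFrame; the sample strings and column names below are illustrative, not from the repository:

import pandas as pd

from eureka_ml_insights.data_utils.arc_agi_utils import ARCAGI_CleanCOTAnswer

# Hypothetical model outputs: one with a chain-of-thought block, one without, one None.
df = pd.DataFrame(
    {
        "cot_model_output": [
            "<think>try a 3x3 grid...</think>\nThe answer is 8",
            "The answer is 8",
            None,
        ]
    }
)

transform = ARCAGI_CleanCOTAnswer("cot_model_output", "post_cot_model_output")
df = transform.transform(df)

# Everything up to and including </think> is removed; None becomes "".
print(df["post_cot_model_output"].tolist())
# ['\nThe answer is 8', 'The answer is 8', '']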

eureka_ml_insights/user_configs/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,8 @@
 from .arc_agi import (
     ARC_AGI_v1_PIPELINE,
     ARC_AGI_v1_PIPELINE_5Run,
+    Phi_ARC_AGI_v1_PIPELINE,
+    Phi_ARC_AGI_v1_PIPELINE_5Run,
 )
 from .ba_calendar import (
     BA_Calendar_Parallel_PIPELINE,

eureka_ml_insights/user_configs/arc_agi.py

Lines changed: 52 additions & 3 deletions

@@ -6,6 +6,7 @@
 from eureka_ml_insights.core.eval_reporting import EvalReporting
 from eureka_ml_insights.data_utils.arc_agi_utils import (
     ARCAGI_ExtractAnswer,
+    ARCAGI_CleanCOTAnswer,
 )
 from eureka_ml_insights.data_utils.data import (
     DataLoader,
@@ -91,13 +92,29 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
         if resume_logdir:
             self.log_dir = resume_from.split("/")[0:len(resume_from.split("/")) - 1]
 
+        # Configure the data post processing component.
+        self.data_post_processing = DataProcessingConfig(
+            component_type=DataProcessing,
+            data_reader_config=DataSetConfig(
+                DataReader,
+                {
+                    "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"),
+                    "format": ".jsonl",
+                    "transform": SequenceTransform([]),
+                },
+            ),
+            output_dir=os.path.join(self.log_dir, "data_post_processing_output"),
+        )
+
         # Configure the evaluation and reporting component for evaluation and dataset level aggregation
         self.evalreporting_comp = EvalReportingConfig(
             component_type=EvalReporting,
             data_reader_config=DataSetConfig(
                 DataReader,
                 {
-                    "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"),
+                    "path": os.path.join(self.data_post_processing.output_dir, "transformed_data.jsonl"),
                     "format": ".jsonl",
                     "transform": SequenceTransform(
                         [
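In the base pipeline the new component's transform list is empty, so this step is a pass-through: inference_result.jsonl is re-emitted as transformed_data.jsonl for the evaluation stage, and subclasses can inject transforms into the empty slot. A minimal sketch of that identity behavior, assuming SequenceTransform applies its transforms in order through the same transform(df) interface the DFTransformBase classes above use (import path assumed):

import pandas as pd

from eureka_ml_insights.data_utils import SequenceTransform  # assumed import path

df = pd.DataFrame({"model_output": ["The answer is 8"]})

# With no transforms configured, the frame passes through unchanged.
assert SequenceTransform([]).transform(df).equals(df)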
@@ -126,8 +143,6 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
                         "group_by": "split",
                     },
                 ),
-                # the next three reports take the average and std for all repeats
-                # the resulting numbers are the average and std of N pass@1 scores, where N is number of repeats
                 AggregatorConfig(
                     CountAggregator,
                     {
@@ -206,6 +221,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
             [
                 self.data_processing_comp,
                 self.inference_comp,
+                self.data_post_processing,
                 self.evalreporting_comp,
                 self.posteval_data_post_processing_comp,
                 self.best_of_n_evalreporting_comp,
@@ -214,6 +230,25 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
         )
 
 
+class Phi_ARC_AGI_v1_PIPELINE(ARC_AGI_v1_PIPELINE):
+    def configure_pipeline(self, model_config=None, resume_from=None, **kwargs):
+        config = super().configure_pipeline(model_config=model_config, resume_from=resume_from)
+        # Post processing for Phi models: strip the chain-of-thought text from
+        # the model output before it reaches evaluation.
+        self.data_post_processing.data_reader_config.init_args["transform"] = SequenceTransform(
+            [
+                ColumnRename(
+                    name_mapping={
+                        "model_output": "cot_model_output",
+                    }
+                ),
+                AddColumn("post_cot_model_output"),
+                ARCAGI_CleanCOTAnswer("cot_model_output", "post_cot_model_output"),
+                CopyColumn("post_cot_model_output", "model_output"),
+            ]
+        )
+        return config
+
+
 class ARC_AGI_v1_PIPELINE_5Run(ARC_AGI_v1_PIPELINE):
     """This class specifies the config for running the ARC-AGI v1 benchmark with 5 repeats"""
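Read end to end, the chain preserves the raw Phi output in cot_model_output and overwrites model_output with only the post-COT text, which is what the evaluation stage scores. A minimal sketch of the same SequenceTransform applied directly to a toy frame (import paths assumed; the sample string is invented):

import pandas as pd

from eureka_ml_insights.data_utils import (  # assumed import path
    AddColumn,
    ColumnRename,
    CopyColumn,
    SequenceTransform,
)
from eureka_ml_insights.data_utils.arc_agi_utils import ARCAGI_CleanCOTAnswer

chain = SequenceTransform(
    [
        ColumnRename(name_mapping={"model_output": "cot_model_output"}),
        AddColumn("post_cot_model_output"),
        ARCAGI_CleanCOTAnswer("cot_model_output", "post_cot_model_output"),
        CopyColumn("post_cot_model_output", "model_output"),
    ]
)

df = pd.DataFrame({"model_output": ["<think>scan the grid...</think>The answer is 8"]})
df = chain.transform(df)

# The raw output survives in cot_model_output; model_output now holds only the
# text after </think>.
print(df.loc[0, "model_output"])  # "The answer is 8"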

@@ -226,3 +261,17 @@ def configure_pipeline(
             MultiplyTransform(n_repeats=5)
         )
         return pipeline
+
+
+class Phi_ARC_AGI_v1_PIPELINE_5Run(Phi_ARC_AGI_v1_PIPELINE):
+    """This class specifies the config for running the Phi ARC-AGI v1 pipeline with 5 repeats"""
+
+    def configure_pipeline(
+        self, model_config: ModelConfig, resume_from: str = None, **kwargs: dict[str, Any]
+    ) -> PipelineConfig:
+        pipeline = super().configure_pipeline(model_config=model_config, resume_from=resume_from)
+        # Data preprocessing: repeat each sample 5 times.
+        self.data_processing_comp.data_reader_config.init_args["transform"].transforms.append(
+            MultiplyTransform(n_repeats=5)
+        )
+        return pipeline
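The 5-run variants only append MultiplyTransform(n_repeats=5) to the preprocessing transform list, duplicating every input row five times so each task is attempted five times downstream. A rough sketch of that behavior, assuming MultiplyTransform exposes the same transform(df) interface as the other transforms (import path assumed):

import pandas as pd

from eureka_ml_insights.data_utils import MultiplyTransform  # assumed import path

df = pd.DataFrame({"prompt": ["task-1", "task-2"]})
repeated = MultiplyTransform(n_repeats=5).transform(df)

# 2 tasks x 5 repeats = 10 rows; downstream aggregators group the repeats to
# report the average and std of the per-repeat pass@1 scores.
print(len(repeated))  # 10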
