66from eureka_ml_insights .core .eval_reporting import EvalReporting
77from eureka_ml_insights .data_utils .arc_agi_utils import (
88 ARCAGI_ExtractAnswer ,
9+ ARCAGI_CleanCOTAnswer ,
910)
1011from eureka_ml_insights .data_utils .data import (
1112 DataLoader ,
@@ -91,13 +92,29 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
9192 if resume_logdir :
9293 self .log_dir = resume_from .split ("/" )[0 :len (resume_from .split ("/" )) - 1 ]
9394
95+ # Configure the data post processing component.
96+ self .data_post_processing = DataProcessingConfig (
97+ component_type = DataProcessing ,
98+ data_reader_config = DataSetConfig (
99+ DataReader ,
100+ {
101+ "path" : os .path .join (self .inference_comp .output_dir , "inference_result.jsonl" ),
102+ "format" : ".jsonl" ,
103+ "transform" : SequenceTransform (
104+ []
105+ ),
106+ },
107+ ),
108+ output_dir = os .path .join (self .log_dir , "data_post_processing_output" ),
109+ )
110+
94111 # Configure the evaluation and reporting component for evaluation and dataset level aggregation
95112 self .evalreporting_comp = EvalReportingConfig (
96113 component_type = EvalReporting ,
97114 data_reader_config = DataSetConfig (
98115 DataReader ,
99116 {
100- "path" : os .path .join (self .inference_comp .output_dir , "inference_result .jsonl" ),
117+ "path" : os .path .join (self .data_post_processing .output_dir , "transformed_data .jsonl" ),
101118 "format" : ".jsonl" ,
102119 "transform" : SequenceTransform (
103120 [
@@ -126,8 +143,6 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
126143 "group_by" : "split" ,
127144 },
128145 ),
129- # the next three reports take the average and std for all repeats
130- # the resulting numbers are the average and std of N pass@1 scores, where N is number of repeats
131146 AggregatorConfig (
132147 CountAggregator ,
133148 {
@@ -206,6 +221,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
206221 [
207222 self .data_processing_comp ,
208223 self .inference_comp ,
224+ self .data_post_processing ,
209225 self .evalreporting_comp ,
210226 self .posteval_data_post_processing_comp ,
211227 self .best_of_n_evalreporting_comp ,
@@ -214,6 +230,25 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
214230 )
215231
216232
class Phi_ARC_AGI_v1_PIPELINE(ARC_AGI_v1_PIPELINE):
    """ARC-AGI v1 pipeline whose post-processing step rewrites ``model_output``.

    The parent pipeline is configured first; afterwards the transform of the
    ``data_post_processing`` reader is replaced so that the raw model output is
    preserved under ``cot_model_output`` and ``model_output`` is overwritten with
    the result produced by ``ARCAGI_CleanCOTAnswer``.
    """

    def configure_pipeline(self, model_config=None, resume_from=None, **kwargs):
        """Build the parent pipeline config and install the output clean-up transforms.

        Args:
            model_config: Forwarded unchanged to the parent ``configure_pipeline``.
            resume_from: Forwarded unchanged to the parent ``configure_pipeline``.
            **kwargs: Accepted but not forwarded to the parent.

        Returns:
            The pipeline config object returned by the parent class.
        """
        # NOTE(review): anything passed via **kwargs (e.g. ``resume_logdir``) is
        # dropped here and never reaches the parent -- confirm this is intended.
        pipeline = super().configure_pipeline(model_config=model_config, resume_from=resume_from)

        cot_cleanup = SequenceTransform(
            [
                # Keep the untouched model response under a separate column name.
                ColumnRename(name_mapping={"model_output": "cot_model_output"}),
                # Destination column for the cleaned answer.
                AddColumn("post_cot_model_output"),
                # Presumably strips the chain-of-thought part (e.g. text before
                # '</think>') from the response -- verify against ARCAGI_CleanCOTAnswer.
                ARCAGI_CleanCOTAnswer("cot_model_output", "post_cot_model_output"),
                # Downstream components read "model_output", so expose the cleaned text there.
                CopyColumn("post_cot_model_output", "model_output"),
            ]
        )
        self.data_post_processing.data_reader_config.init_args["transform"] = cot_cleanup
        return pipeline
250+
251+
217252class ARC_AGI_v1_PIPELINE_5Run (ARC_AGI_v1_PIPELINE ):
218253 """This class specifies the config for running the ARC-AGI v1 benchmark 5 repeated times"""
219254
@@ -226,3 +261,17 @@ def configure_pipeline(
226261 MultiplyTransform (n_repeats = 5 )
227262 )
228263 return pipeline
264+
265+
class Phi_ARC_AGI_v1_PIPELINE_5Run(Phi_ARC_AGI_v1_PIPELINE):
    """Config for running the Phi variant of the ARC-AGI v1 benchmark 5 repeated times.

    Inherits from ``Phi_ARC_AGI_v1_PIPELINE`` (not the plain ``ARC_AGI_v1_PIPELINE``)
    so that the Phi-specific post-processing transforms stay in effect; otherwise
    this class would be indistinguishable from ``ARC_AGI_v1_PIPELINE_5Run``.
    """

    def configure_pipeline(
        self, model_config: ModelConfig, resume_from: str = None, **kwargs: dict[str, Any]
    ) -> PipelineConfig:
        """Configure the Phi pipeline and repeat the pre-processed data 5 times.

        Args:
            model_config: Forwarded unchanged to the parent ``configure_pipeline``.
            resume_from: Forwarded unchanged to the parent ``configure_pipeline``.
            **kwargs: Accepted for signature compatibility; not forwarded.

        Returns:
            The pipeline config returned by the parent class.
        """
        pipeline = super().configure_pipeline(model_config=model_config, resume_from=resume_from)
        # data preprocessing: append the repeat step to the existing transform sequence
        self.data_processing_comp.data_reader_config.init_args["transform"].transforms.append(
            MultiplyTransform(n_repeats=5)
        )
        return pipeline
0 commit comments