
Commit 3982535

Update ARC AGI experiment pipeline and add 5x pipeline experiments
1 parent 9806556 commit 3982535

1 file changed: +212 −0 lines changed

eureka_ml_insights/user_configs/arc_agi.py

Lines changed: 212 additions & 0 deletions
@@ -275,3 +275,215 @@ def configure_pipeline(
            MultiplyTransform(n_repeats=5)
        )
        return pipeline


class ARC_AGI_v1_PIPELINE_5050_SUBSET(ExperimentConfig):
    """Specifies the config for running the ARC-AGI-v1 50/50 subset benchmark on any model."""

    def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=None, **kwargs) -> PipelineConfig:
        # data preprocessing
        self.data_processing_comp = PromptProcessingConfig(
            component_type=PromptProcessing,
            prompt_template_path=os.path.join(
                os.path.dirname(__file__), "../prompt_templates/arc_agi_templates/arc_agi_v1_grid_explanation.jinja"
            ),
            data_reader_config=DataSetConfig(
                HFDataReader,
                {
                    "path": "pxferna/ARC-AGI-v1-5050",
                    "split": "train",
                    "transform": SequenceTransform(
                        [
                            MultiplyTransform(n_repeats=1),
                        ]
                    ),
                }
            ),
            output_dir=os.path.join(self.log_dir, "data_processing_output"),
        )

        # inference component
        self.inference_comp = InferenceConfig(
            component_type=Inference,
            model_config=model_config,
            data_loader_config=DataSetConfig(
                DataLoader,
                {"path": os.path.join(self.data_processing_comp.output_dir, "transformed_data.jsonl")},
            ),
            output_dir=os.path.join(self.log_dir, "inference_result"),
            resume_from=resume_from,
            max_concurrent=16,
        )

        if resume_logdir:
            # Resume into the log directory that contains the resume_from file.
            self.log_dir = "/".join(resume_from.split("/")[:-1])

        # Configure the data post processing component.
        self.data_post_processing = DataProcessingConfig(
            component_type=DataProcessing,
            data_reader_config=DataSetConfig(
                DataReader,
                {
                    "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"),
                    "format": ".jsonl",
                    "transform": SequenceTransform(
                        []
                    ),
                },
            ),
            output_dir=os.path.join(self.log_dir, "data_post_processing_output"),
        )

        # Configure the evaluation and reporting component for evaluation and dataset-level aggregation.
        self.evalreporting_comp = EvalReportingConfig(
            component_type=EvalReporting,
            data_reader_config=DataSetConfig(
                DataReader,
                {
                    "path": os.path.join(self.data_post_processing.output_dir, "transformed_data.jsonl"),
                    "format": ".jsonl",
                    "transform": SequenceTransform(
                        [
                            ExtractUsageTransform(model_config),
                            ColumnRename(
                                name_mapping={
                                    "model_output": "raw_output",
                                }
                            ),
                            AddColumn("model_output"),
                            ARCAGI_ExtractAnswer("raw_output", "model_output"),
                        ]
                    ),
                },
            ),
            metric_config=MetricConfig(ExactMatch),
            aggregator_configs=[
                AggregatorConfig(
                    CountAggregator,
                    {
                        "column_names": [
                            "ExactMatch_result",
                        ],
                        "filename_base": "OverallMetrics_Separate_Runs_Grouped",
                        "normalize": True,
                        "group_by": "split",
                    },
                ),
                AggregatorConfig(
                    CountAggregator,
                    {
                        "column_names": [
                            "ExactMatch_result",
                        ],
                        "normalize": True,
                        "filename_base": "OverallMetrics_Separate_Runs_Total",
                    },
                ),
            ],
            output_dir=os.path.join(self.log_dir, "eval_report"),
        )

        # Convert the categorical ExactMatch_result into a numeric column for best-of-n aggregation.
        self.posteval_data_post_processing_comp = DataProcessingConfig(
            component_type=DataProcessing,
            data_reader_config=DataSetConfig(
                DataReader,
                {
                    "path": os.path.join(self.evalreporting_comp.output_dir, "metric_results.jsonl"),
                    "format": ".jsonl",
                    "transform": SequenceTransform(
                        [
                            CopyColumn(
                                column_name_src="ExactMatch_result",
                                column_name_dst="ExactMatch_result_numeric",
                            ),
                            ReplaceStringsTransform(
                                columns=["ExactMatch_result_numeric"],
                                mapping={'incorrect': '0', 'correct': '1', 'none': 'NaN'},
                                case=False,
                            ),
                        ]
                    ),
                },
            ),
            output_dir=os.path.join(self.log_dir, "posteval_data_post_processing_output"),
        )

        # Report best-of-n accuracy: take the best result per uid across repeated runs,
        # overall and grouped by split.
        self.best_of_n_evalreporting_comp = EvalReportingConfig(
            component_type=EvalReporting,
            data_reader_config=DataSetConfig(
                DataReader,
                {
                    "path": os.path.join(self.posteval_data_post_processing_comp.output_dir, "transformed_data.jsonl"),
                    "format": ".jsonl",
                },
            ),
            aggregator_configs=[
                AggregatorConfig(
                    BiLevelAggregator,
                    {
                        "column_names": [
                            "ExactMatch_result_numeric",
                        ],
                        "first_groupby": "uid",
                        "filename_base": "ExactMatch_Total_BestOfN",
                        "agg_fn": "max",
                    },
                ),
                AggregatorConfig(
                    BiLevelAggregator,
                    {
                        "column_names": [
                            "ExactMatch_result_numeric",
                        ],
                        "first_groupby": "uid",
                        "second_groupby": "split",
                        "filename_base": "ExactMatch_Grouped_BestOfN",
                        "agg_fn": "max",
                    },
                ),
            ],
            output_dir=os.path.join(self.log_dir, "bestofn_eval_report"),
        )

        # Configure the pipeline.
        return PipelineConfig(
            [
                self.data_processing_comp,
                self.inference_comp,
                self.data_post_processing,
                self.evalreporting_comp,
                self.posteval_data_post_processing_comp,
                self.best_of_n_evalreporting_comp,
            ],
            self.log_dir,
        )


class COT_ARC_AGI_v1_PIPELINE_5050_SUBSET(ARC_AGI_v1_PIPELINE_5050_SUBSET):
    """Variant of the 50/50 subset pipeline that strips the chain-of-thought reasoning from the model output before answer extraction."""

    def configure_pipeline(self, model_config=None, resume_from=None, **kwargs):
        config = super().configure_pipeline(model_config=model_config, resume_from=resume_from)
        # Replace the default post-processing: keep the raw CoT output, clean it, and copy the cleaned answer back into model_output.
        self.data_post_processing.data_reader_config.init_args["transform"] = SequenceTransform(
            [
                ColumnRename(
                    name_mapping={
                        "model_output": "cot_model_output",
                    }
                ),
                AddColumn("post_cot_model_output"),
                # RunPythonTransform("df['post_cot_model_output'] = df['post_cot_model_output'].apply(lambda x: x.split('</think>')[-1] if '</think>' in x else x)"),
                ARCAGI_CleanCOTAnswer("cot_model_output", "post_cot_model_output"),
                CopyColumn("post_cot_model_output", "model_output"),
            ]
        )
        return config


class COT_ARC_AGI_v1_PIPELINE_5050_SUBSET_5Run(COT_ARC_AGI_v1_PIPELINE_5050_SUBSET):
    """Specifies the config for running the ARC-AGI-v1 50/50 subset CoT benchmark 5 repeated times."""

    def configure_pipeline(
        self, model_config: ModelConfig, resume_from: str = None, **kwargs: dict[str, Any]
    ) -> PipelineConfig:
        pipeline = super().configure_pipeline(model_config=model_config, resume_from=resume_from)
        # data preprocessing: repeat every example 5 times so best-of-n can be computed downstream
        self.data_processing_comp.data_reader_config.init_args["transform"].transforms.append(
            MultiplyTransform(n_repeats=5)
        )
        return pipeline
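
For intuition, the best-of-n numbers produced by the posteval string-to-number conversion and the BiLevelAggregator correspond roughly to the small pandas computation below. This is only an illustrative sketch on made-up rows, not the eureka_ml_insights implementation; it assumes the uid, split, and ExactMatch_result columns shown in the config above.

# Illustrative sketch only: approximates the best-of-n report on made-up data,
# not the eureka_ml_insights implementation.
import pandas as pd

# Each task (uid) is attempted several times (5 in the 5Run config above).
rows = pd.DataFrame(
    {
        "uid": ["a", "a", "b", "b"],
        "split": ["train", "train", "train", "train"],
        "ExactMatch_result": ["incorrect", "correct", "incorrect", "incorrect"],
    }
)

# Equivalent of the ReplaceStringsTransform step: categorical result -> numeric.
rows["ExactMatch_result_numeric"] = rows["ExactMatch_result"].map(
    {"incorrect": 0.0, "correct": 1.0, "none": float("nan")}
)

# Best-of-n: best attempt per uid, then the mean over uids within each split.
best_per_uid = rows.groupby(["split", "uid"])["ExactMatch_result_numeric"].max()
print(best_per_uid.groupby(level="split").mean())  # train: 0.5 (uid "a" solved at least once)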

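The commented-out RunPythonTransform in the CoT subclass hints at what ARCAGI_CleanCOTAnswer is expected to do: drop everything up to the final </think> tag so the answer extractor only sees the text after the reasoning. A minimal standalone sketch of that string handling, written as an assumption about the transform's behavior rather than its actual implementation:

# Assumed behavior, inferred from the commented-out RunPythonTransform above;
# the real ARCAGI_CleanCOTAnswer transform may do additional cleanup.
def strip_cot(raw_output: str) -> str:
    """Keep only the text after the last </think> tag, if one is present."""
    return raw_output.split("</think>")[-1] if "</think>" in raw_output else raw_output


print(strip_cot("<think>the grid is mirrored...</think>\n[[1, 0], [0, 1]]"))  # -> "\n[[1, 0], [0, 1]]"
print(strip_cot("[[1, 0], [0, 1]]"))  # unchanged when no reasoning tag is present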