Skip to content

Commit 75c9de5

Browse files
committed
Add ARC-AGI-v1 dataset experiment pipeline
1 parent 31d5fce commit 75c9de5

File tree

3 files changed

+100
-0
lines changed

3 files changed

+100
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
You are an intelligent assistant who is very good at answering test questions accurately.
2+
3+
{{ prompt }}

eureka_ml_insights/user_configs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
AIME_PIPELINE,
66
)
77
from .aime_seq import AIME_SEQ_PIPELINE
8+
from .arc_agi import (
9+
ARC_AGI_v1_PIPELINE,
10+
)
811
from .ba_calendar import (
912
BA_Calendar_Parallel_PIPELINE,
1013
BA_Calendar_PIPELINE,
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import os
2+
from typing import Any
3+
4+
from eureka_ml_insights.core import Inference, PromptProcessing
5+
from eureka_ml_insights.core.data_processing import DataProcessing
6+
from eureka_ml_insights.core.eval_reporting import EvalReporting
7+
from eureka_ml_insights.data_utils.ba_calendar_utils import (
8+
BA_Calendar_ExtractAnswer,
9+
)
10+
from eureka_ml_insights.data_utils.data import (
11+
DataLoader,
12+
DataReader,
13+
HFDataReader,
14+
)
15+
from eureka_ml_insights.metrics.ba_calendar_metrics import BACalendarMetric
16+
from eureka_ml_insights.metrics.reports import (
17+
AverageAggregator,
18+
BiLevelCountAggregator,
19+
BiLevelAggregator,
20+
CountAggregator
21+
)
22+
23+
from eureka_ml_insights.data_utils.transform import (
24+
AddColumn,
25+
AddColumnAndData,
26+
ColumnRename,
27+
CopyColumn,
28+
ExtractUsageTransform,
29+
MajorityVoteTransform,
30+
MultiplyTransform,
31+
RunPythonTransform,
32+
SamplerTransform,
33+
SequenceTransform,
34+
)
35+
from eureka_ml_insights.metrics.ba_calendar_metrics import BACalendarMetric
36+
37+
from ..configs.config import (
38+
AggregatorConfig,
39+
DataProcessingConfig,
40+
DataSetConfig,
41+
EvalReportingConfig,
42+
InferenceConfig,
43+
MetricConfig,
44+
ModelConfig,
45+
PipelineConfig,
46+
PromptProcessingConfig,
47+
)
48+
from ..configs.experiment_config import ExperimentConfig
49+
50+
51+
class ARC_AGI_v1_PIPELINE(ExperimentConfig):
    """Experiment config for running a model on the ARC-AGI-v1 dataset.

    The pipeline consists of two components: a prompt-processing step that
    reads the ``test`` split of ``pxferna/ARC-AGI-v1`` through ``HFDataReader``
    and renders the ``arc_agi_v1_basic.jinja`` template, followed by an
    inference step that consumes the rendered prompts.
    """

    def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=None, **kwargs) -> PipelineConfig:
        """Build the prompt-processing and inference components of the pipeline.

        Args:
            model_config: Model configuration handed to the inference component.
            resume_from: Optional path forwarded to the inference component as
                its ``resume_from`` setting.
            resume_logdir: When truthy (and ``resume_from`` is given), the
                pipeline log directory is replaced by the directory that
                contains ``resume_from``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``PipelineConfig`` listing the two components and the log
            directory.
        """
        template_path = os.path.join(
            os.path.dirname(__file__), "../prompt_templates/arc_agi_templates/arc_agi_v1_basic.jinja"
        )

        # data preprocessing
        self.data_processing_comp = PromptProcessingConfig(
            component_type=PromptProcessing,
            prompt_template_path=template_path,
            data_reader_config=DataSetConfig(
                HFDataReader,
                {
                    "path": "pxferna/ARC-AGI-v1",
                    "split": "test",
                },
            ),
            output_dir=os.path.join(self.log_dir, "data_processing_output"),
        )

        # inference component: reads the prompts written by the step above
        prompts_path = os.path.join(self.data_processing_comp.output_dir, "transformed_data.jsonl")
        self.inference_comp = InferenceConfig(
            component_type=Inference,
            model_config=model_config,
            data_loader_config=DataSetConfig(
                DataLoader,
                {"path": prompts_path},
            ),
            output_dir=os.path.join(self.log_dir, "inference_result"),
            resume_from=resume_from,
            max_concurrent=1,
        )

        # The previous code assigned a *list* of path components to
        # self.log_dir and raised AttributeError when resume_logdir was set
        # without resume_from. Keep log_dir a path string and only redirect it
        # when there is a resume path to derive it from.
        # NOTE: the component output directories above were already computed
        # from the previous log_dir; only the value passed to PipelineConfig
        # changes here.
        # NOTE(review): resume_logdir is used purely as a flag, as in the
        # original code — confirm whether its value should be the log dir.
        if resume_logdir and resume_from:
            self.log_dir = os.path.dirname(resume_from)

        # Configure the pipeline
        return PipelineConfig(
            [
                self.data_processing_comp,
                self.inference_comp,
            ],
            self.log_dir,
        )

0 commit comments

Comments
 (0)