Commit f02ad90

Add metrics and aggregators for ARC AGI pipeline
1 parent 75c9de5 commit f02ad90

2 files changed: 94 additions, 3 deletions
eureka_ml_insights/data_utils/arc_agi_utils.py (new file)

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+import re
+from dataclasses import dataclass
+
+import pandas as pd
+
+from .transform import DFTransformBase
+
+
+@dataclass
+class ARCAGI_ExtractAnswer(DFTransformBase):
+    model_output_column: str
+    model_answer_column: str
+
+    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        df[self.model_answer_column] = df[self.model_output_column].apply(self.parse_output_answer)
+        return df
+
+    @staticmethod
+    def parse_output_answer(response):
+        """
+        Parse the input string to extract the answer to a given ARC AGI question.
+        Parameters:
+            response (str): Input string containing the answer in the form "<output>final answer string</output>".
+        Returns:
+            answer (str): The final answer string with leading and trailing spaces stripped.
+        """
+        answer = ""
+
+        if response is None:
+            return ""
+        elif response.find("<output>") == -1 or response.find("</output>") == -1:
+            return ""
+
+        start_index = response.find("<output>") + len("<output>")
+        end_index = response.find("</output>")
+
+        answer = response[start_index:end_index].strip()
+
+        return answer
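
A minimal usage sketch of the transform added above (hypothetical column names and data; assumes DFTransformBase adds no required fields beyond those declared here):

    import pandas as pd
    from eureka_ml_insights.data_utils.arc_agi_utils import ARCAGI_ExtractAnswer

    # Toy model outputs; only the first row carries the expected <output>...</output> tags.
    df = pd.DataFrame({"raw_output": ["<output> 1 2 0 </output>", "no tags here", None]})

    extractor = ARCAGI_ExtractAnswer(model_output_column="raw_output", model_answer_column="model_output")
    df = extractor.transform(df)
    # df["model_output"] is now ["1 2 0", "", ""]: tagged answers are stripped,
    # while untagged or missing responses fall back to the empty string.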

eureka_ml_insights/user_configs/arc_agi.py

Lines changed: 55 additions & 3 deletions
@@ -4,16 +4,17 @@
 from eureka_ml_insights.core import Inference, PromptProcessing
 from eureka_ml_insights.core.data_processing import DataProcessing
 from eureka_ml_insights.core.eval_reporting import EvalReporting
-from eureka_ml_insights.data_utils.ba_calendar_utils import (
-    BA_Calendar_ExtractAnswer,
+from eureka_ml_insights.data_utils.arc_agi_utils import (
+    ARCAGI_ExtractAnswer,
 )
 from eureka_ml_insights.data_utils.data import (
     DataLoader,
     DataReader,
     HFDataReader,
 )
-from eureka_ml_insights.metrics.ba_calendar_metrics import BACalendarMetric
+from eureka_ml_insights.metrics.metrics_base import ExactMatch
 from eureka_ml_insights.metrics.reports import (
+    CountAggregator,
     AverageAggregator,
     BiLevelCountAggregator,
     BiLevelAggregator,
@@ -84,11 +85,62 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
         if resume_logdir:
             self.log_dir = resume_from.split("/")[0:len(resume_from.split("/")) - 1]
 
+        # Configure the evaluation and reporting component for evaluation and dataset level aggregation
+        self.evalreporting_comp = EvalReportingConfig(
+            component_type=EvalReporting,
+            data_reader_config=DataSetConfig(
+                DataReader,
+                {
+                    "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"),
+                    "format": ".jsonl",
+                    "transform": SequenceTransform(
+                        [
+                            ExtractUsageTransform(model_config),
+                            ColumnRename(
+                                name_mapping={
+                                    "model_output": "raw_output",
+                                }
+                            ),
+                            AddColumn("model_output"),
+                            ARCAGI_ExtractAnswer("raw_output", "model_output"),
+                        ]
+                    ),
+                },
+            ),
+            metric_config=MetricConfig(ExactMatch),
+            aggregator_configs=[
+                AggregatorConfig(
+                    CountAggregator,
+                    {
+                        "column_names": [
+                            "ExactMatch_result",
+                        ],
+                        "filename_base": "OverallMetrics_Separate_Runs_Grouped",
+                        "normalize": True,
+                        "group_by": "split",
+                    },
+                ),
+                # the next three reports take the average and std for all repeats
+                # the resulting numbers are the average and std of N pass@1 scores, where N is number of repeats
+                AggregatorConfig(
+                    CountAggregator,
+                    {
+                        "column_names": [
+                            "ExactMatch_result",
+                        ],
+                        "normalize": True,
+                        "filename_base": "OverallMetrics_Separate_Runs_Total",
+                    }),
+            ],
+            output_dir=os.path.join(self.log_dir, "eval_report"),
+        )
+
         # Configure the pipeline
         return PipelineConfig(
             [
                 self.data_processing_comp,
                 self.inference_comp,
+                self.evalreporting_comp,
             ],
             self.log_dir,
         )
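
For intuition on the two aggregators above: with "normalize": True, the CountAggregator reports the share of each ExactMatch_result value rather than raw counts, once grouped by "split" and once over the whole dataset. A rough pandas stand-in for that computation (made-up rows and result labels, not the library's own CountAggregator code):

    import pandas as pd

    # Made-up per-example metric results; actual ExactMatch_result values may differ.
    rows = pd.DataFrame({
        "split": ["train", "train", "test", "test"],
        "ExactMatch_result": ["correct", "incorrect", "correct", "correct"],
    })

    # Roughly what OverallMetrics_Separate_Runs_Grouped reports: fractions per split.
    grouped = rows.groupby("split")["ExactMatch_result"].value_counts(normalize=True)

    # Roughly what OverallMetrics_Separate_Runs_Total reports: fractions overall.
    total = rows["ExactMatch_result"].value_counts(normalize=True)

    print(grouped)
    print(total)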
