Skip to content

Commit a5fcc6f

Browse files
committed
Make think tag name a parameter and fix unused imports
1 parent f965777 commit a5fcc6f

File tree

5 files changed

+40
-60
lines changed

5 files changed

+40
-60
lines changed

eureka_ml_insights/data_utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
AddColumn,
2424
AddColumnAndData,
2525
ASTEvalTransform,
26-
CleanCOTAnswer,
2726
ColumnMatchMapTransform,
2827
ColumnRename,
2928
CopyColumn,

eureka_ml_insights/data_utils/arc_agi_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import re
21
from dataclasses import dataclass
32

43
import pandas as pd
@@ -21,8 +20,8 @@ def parse_output_answer(response):
2120
Parse the input string to extract answer of a given ARCAGI question.
2221
Parameters:
2322
response (str): Input string containing answer X in the form of "<output>final answer string</output>".
24-
Returns:
25-
answer (str): The final answer string with leading and training spaces stripped.
23+
Returns:
24+
answer (str): The final answer string with leading and trailing spaces stripped.
2625
"""
2726
answer = ""
2827

eureka_ml_insights/data_utils/transform.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -549,11 +549,12 @@ def _extract_usage(self, row, usage_completion_read_col):
549549
@dataclass
550550
class CleanCOTAnswer(DFTransformBase):
551551
"""
552-
Transform to strip out anything before and including the </think> tag in the model response
552+
Transform to strip out anything before and including the </think_tag_name> tag in the model response
553553
"""
554554

555555
model_output_column: str
556556
model_answer_column: str
557+
think_tag_name: str = "think"
557558

558559
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
559560
df[self.model_answer_column] = df[self.model_output_column].apply(self.parse_output_answer)
@@ -562,21 +563,22 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
562563
def parse_output_answer(self, response):
    """Return *response* with any chain-of-thought prefix stripped out.

    The chain of thought is expected to be wrapped in opening/closing tags
    named by ``self.think_tag_name`` (default ``"think"``); everything up to
    and including the closing ``</{think_tag_name}>`` tag is removed.

    Parameters:
        response (str | None): Possibly null response string with chain of
            thought wrapped in the configured think tags.

    Returns:
        str: Response string with None replaced by a blank string and the
        chain of thought (if present) stripped out.
    """
    # NOTE(review): the committed version kept @staticmethod while the new
    # code reads self.think_tag_name, which raises NameError at runtime.
    # Converted to an instance method so the configured tag name is actually
    # used; the internal call site df[...].apply(self.parse_output_answer)
    # passes a single argument either way, so it is unaffected.
    if response is None:
        return ""

    closing_tag = f"</{self.think_tag_name}>"
    tag_index = response.find(closing_tag)
    if tag_index == -1:
        # No closing think tag present: return the response unchanged.
        return response

    # Drop everything up to and including the closing tag.
    return response[tag_index + len(closing_tag):]

eureka_ml_insights/user_configs/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,6 @@
55
AIME_PIPELINE,
66
)
77
from .aime_seq import AIME_SEQ_PIPELINE
8-
from .arc_agi import (
9-
ARC_AGI_v1_PIPELINE,
10-
ARC_AGI_v1_PIPELINE_5Run,
11-
COT_ARC_AGI_v1_PIPELINE,
12-
COT_ARC_AGI_v1_PIPELINE_5Run,
13-
COT_ARC_AGI_v1_PIPELINE_5050_SUBSET,
14-
COT_ARC_AGI_v1_PIPELINE_5050_SUBSET_5Run,
15-
)
168
from .ba_calendar import (
179
BA_Calendar_Parallel_PIPELINE,
1810
BA_Calendar_PIPELINE,

eureka_ml_insights/user_configs/arc_agi.py

Lines changed: 26 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,27 @@
44
from eureka_ml_insights.core import Inference, PromptProcessing
55
from eureka_ml_insights.core.data_processing import DataProcessing
66
from eureka_ml_insights.core.eval_reporting import EvalReporting
7-
from eureka_ml_insights.data_utils.arc_agi_utils import (
8-
ARCAGI_ExtractAnswer
9-
)
7+
from eureka_ml_insights.data_utils.arc_agi_utils import ARCAGI_ExtractAnswer
108
from eureka_ml_insights.data_utils.data import (
119
DataLoader,
1210
DataReader,
1311
HFDataReader,
1412
)
15-
from eureka_ml_insights.metrics.metrics_base import ExactMatch
16-
from eureka_ml_insights.metrics.reports import (
17-
CountAggregator,
18-
AverageAggregator,
19-
BiLevelCountAggregator,
20-
BiLevelAggregator,
21-
CountAggregator
22-
)
23-
2413
from eureka_ml_insights.data_utils.transform import (
2514
AddColumn,
26-
AddColumnAndData,
2715
CleanCOTAnswer,
2816
ColumnRename,
2917
CopyColumn,
3018
ExtractUsageTransform,
31-
MajorityVoteTransform,
3219
MultiplyTransform,
3320
ReplaceStringsTransform,
34-
RunPythonTransform,
35-
SamplerTransform,
3621
SequenceTransform,
3722
)
38-
from eureka_ml_insights.metrics.ba_calendar_metrics import BACalendarMetric
23+
from eureka_ml_insights.metrics.metrics_base import ExactMatch
24+
from eureka_ml_insights.metrics.reports import (
25+
BiLevelAggregator,
26+
CountAggregator,
27+
)
3928

4029
from ..configs.config import (
4130
AggregatorConfig,
@@ -64,14 +53,14 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
6453
data_reader_config=DataSetConfig(
6554
HFDataReader,
6655
{
67-
"path": "pxferna/ARC-AGI-v1",
68-
"split": "test",
56+
"path": "pxferna/ARC-AGI-v1",
57+
"split": "test",
6958
"transform": SequenceTransform(
7059
[
7160
MultiplyTransform(n_repeats=1),
7261
]
7362
),
74-
}
63+
},
7564
),
7665
output_dir=os.path.join(self.log_dir, "data_processing_output"),
7766
)
@@ -100,9 +89,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
10089
{
10190
"path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"),
10291
"format": ".jsonl",
103-
"transform": SequenceTransform(
104-
[]
105-
),
92+
"transform": SequenceTransform([]),
10693
},
10794
),
10895
output_dir=os.path.join(self.log_dir, "data_post_processing_output"),
@@ -144,14 +131,15 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
144131
},
145132
),
146133
AggregatorConfig(
147-
CountAggregator,
134+
CountAggregator,
148135
{
149136
"column_names": [
150137
"ExactMatch_result",
151138
],
152139
"normalize": True,
153140
"filename_base": "OverallMetrics_Total",
154-
}),
141+
},
142+
),
155143
],
156144
output_dir=os.path.join(self.log_dir, "eval_report"),
157145
)
@@ -165,14 +153,15 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
165153
"format": ".jsonl",
166154
"transform": SequenceTransform(
167155
[
168-
CopyColumn(
156+
CopyColumn(
169157
column_name_src="ExactMatch_result",
170158
column_name_dst="ExactMatch_result_numeric",
171159
),
172-
ReplaceStringsTransform(
160+
ReplaceStringsTransform(
173161
columns=["ExactMatch_result_numeric"],
174-
mapping={'incorrect': '0', 'correct': '1', 'none': 'NaN'},
175-
case=False)
162+
mapping={"incorrect": "0", "correct": "1", "none": "NaN"},
163+
case=False,
164+
),
176165
]
177166
),
178167
},
@@ -186,30 +175,29 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
186175
DataReader,
187176
{
188177
"path": os.path.join(self.posteval_data_post_processing_comp.output_dir, "transformed_data.jsonl"),
189-
"format": ".jsonl"
178+
"format": ".jsonl",
190179
},
191180
),
192181
aggregator_configs=[
193182
AggregatorConfig(
194-
BiLevelAggregator,
183+
BiLevelAggregator,
195184
{
196185
"column_names": [
197186
"ExactMatch_result_numeric",
198187
],
199188
"first_groupby": "data_point_id",
200189
"filename_base": "ExactMatch_Total_BestOfN",
201-
"agg_fn": "max"
202-
}),
190+
"agg_fn": "max",
191+
},
192+
),
203193
AggregatorConfig(
204194
BiLevelAggregator,
205195
{
206-
"column_names": [
207-
"ExactMatch_result_numeric"
208-
],
196+
"column_names": ["ExactMatch_result_numeric"],
209197
"first_groupby": "data_point_id",
210198
"second_groupby": "split",
211199
"filename_base": "ExactMatch_Grouped_by_Split_BestOfN",
212-
"agg_fn": "max"
200+
"agg_fn": "max",
213201
},
214202
),
215203
],
@@ -301,7 +289,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs):
301289
MultiplyTransform(n_repeats=1),
302290
]
303291
),
304-
}
292+
},
305293
)
306294

307295
return config

0 commit comments

Comments
 (0)