44from eureka_ml_insights .core import Inference , PromptProcessing
55from eureka_ml_insights .core .data_processing import DataProcessing
66from eureka_ml_insights .core .eval_reporting import EvalReporting
7- from eureka_ml_insights .data_utils .arc_agi_utils import (
8- ARCAGI_ExtractAnswer
9- )
7+ from eureka_ml_insights .data_utils .arc_agi_utils import ARCAGI_ExtractAnswer
108from eureka_ml_insights .data_utils .data import (
119 DataLoader ,
1210 DataReader ,
1311 HFDataReader ,
1412)
15- from eureka_ml_insights .metrics .metrics_base import ExactMatch
16- from eureka_ml_insights .metrics .reports import (
17- CountAggregator ,
18- AverageAggregator ,
19- BiLevelCountAggregator ,
20- BiLevelAggregator ,
21- CountAggregator
22- )
23-
2413from eureka_ml_insights .data_utils .transform import (
2514 AddColumn ,
26- AddColumnAndData ,
2715 CleanCOTAnswer ,
2816 ColumnRename ,
2917 CopyColumn ,
3018 ExtractUsageTransform ,
31- MajorityVoteTransform ,
3219 MultiplyTransform ,
3320 ReplaceStringsTransform ,
34- RunPythonTransform ,
35- SamplerTransform ,
3621 SequenceTransform ,
3722)
38- from eureka_ml_insights .metrics .ba_calendar_metrics import BACalendarMetric
23+ from eureka_ml_insights .metrics .metrics_base import ExactMatch
24+ from eureka_ml_insights .metrics .reports import (
25+ BiLevelAggregator ,
26+ CountAggregator ,
27+ )
3928
4029from ..configs .config import (
4130 AggregatorConfig ,
@@ -64,14 +53,14 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
6453 data_reader_config = DataSetConfig (
6554 HFDataReader ,
6655 {
67- "path" : "pxferna/ARC-AGI-v1" ,
68- "split" : "test" ,
56+ "path" : "pxferna/ARC-AGI-v1" ,
57+ "split" : "test" ,
6958 "transform" : SequenceTransform (
7059 [
7160 MultiplyTransform (n_repeats = 1 ),
7261 ]
7362 ),
74- }
63+ },
7564 ),
7665 output_dir = os .path .join (self .log_dir , "data_processing_output" ),
7766 )
@@ -100,9 +89,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
10089 {
10190 "path" : os .path .join (self .inference_comp .output_dir , "inference_result.jsonl" ),
10291 "format" : ".jsonl" ,
103- "transform" : SequenceTransform (
104- []
105- ),
92+ "transform" : SequenceTransform ([]),
10693 },
10794 ),
10895 output_dir = os .path .join (self .log_dir , "data_post_processing_output" ),
@@ -144,14 +131,15 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
144131 },
145132 ),
146133 AggregatorConfig (
147- CountAggregator ,
134+ CountAggregator ,
148135 {
149136 "column_names" : [
150137 "ExactMatch_result" ,
151138 ],
152139 "normalize" : True ,
153140 "filename_base" : "OverallMetrics_Total" ,
154- }),
141+ },
142+ ),
155143 ],
156144 output_dir = os .path .join (self .log_dir , "eval_report" ),
157145 )
@@ -165,14 +153,15 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
165153 "format" : ".jsonl" ,
166154 "transform" : SequenceTransform (
167155 [
168- CopyColumn (
156+ CopyColumn (
169157 column_name_src = "ExactMatch_result" ,
170158 column_name_dst = "ExactMatch_result_numeric" ,
171159 ),
172- ReplaceStringsTransform (
160+ ReplaceStringsTransform (
173161 columns = ["ExactMatch_result_numeric" ],
174- mapping = {'incorrect' : '0' , 'correct' : '1' , 'none' : 'NaN' },
175- case = False )
162+ mapping = {"incorrect" : "0" , "correct" : "1" , "none" : "NaN" },
163+ case = False ,
164+ ),
176165 ]
177166 ),
178167 },
@@ -186,30 +175,29 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P
186175 DataReader ,
187176 {
188177 "path" : os .path .join (self .posteval_data_post_processing_comp .output_dir , "transformed_data.jsonl" ),
189- "format" : ".jsonl"
178+ "format" : ".jsonl" ,
190179 },
191180 ),
192181 aggregator_configs = [
193182 AggregatorConfig (
194- BiLevelAggregator ,
183+ BiLevelAggregator ,
195184 {
196185 "column_names" : [
197186 "ExactMatch_result_numeric" ,
198187 ],
199188 "first_groupby" : "data_point_id" ,
200189 "filename_base" : "ExactMatch_Total_BestOfN" ,
201- "agg_fn" : "max"
202- }),
190+ "agg_fn" : "max" ,
191+ },
192+ ),
203193 AggregatorConfig (
204194 BiLevelAggregator ,
205195 {
206- "column_names" : [
207- "ExactMatch_result_numeric"
208- ],
196+ "column_names" : ["ExactMatch_result_numeric" ],
209197 "first_groupby" : "data_point_id" ,
210198 "second_groupby" : "split" ,
211199 "filename_base" : "ExactMatch_Grouped_by_Split_BestOfN" ,
212- "agg_fn" : "max"
200+ "agg_fn" : "max" ,
213201 },
214202 ),
215203 ],
@@ -301,7 +289,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs):
301289 MultiplyTransform (n_repeats = 1 ),
302290 ]
303291 ),
304- }
292+ },
305293 )
306294
307295 return config
0 commit comments