 from eureka_ml_insights.core import Inference, PromptProcessing
 from eureka_ml_insights.core.data_processing import DataProcessing
 from eureka_ml_insights.core.eval_reporting import EvalReporting
-from eureka_ml_insights.data_utils.ba_calendar_utils import (
-    BA_Calendar_ExtractAnswer,
+from eureka_ml_insights.data_utils.arc_agi_utils import (
+    ARCAGI_ExtractAnswer,
 )
 from eureka_ml_insights.data_utils.data import (
     DataLoader,
     DataReader,
     HFDataReader,
 )
-from eureka_ml_insights.metrics.ba_calendar_metrics import BACalendarMetric
+from eureka_ml_insights.metrics.metrics_base import ExactMatch
 from eureka_ml_insights.metrics.reports import (
+    CountAggregator,
     AverageAggregator,
     BiLevelCountAggregator,
     BiLevelAggregator,
@@ -84,11 +85,62 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir= |
         if resume_logdir:
             self.log_dir = resume_from.split("/")[0:len(resume_from.split("/")) - 1]
 
+        # Configure the evaluation and reporting component for metric computation and dataset-level aggregation
+        self.evalreporting_comp = EvalReportingConfig(
+            component_type=EvalReporting,
+            data_reader_config=DataSetConfig(
+                DataReader,
+                {
+                    "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"),
+                    "format": ".jsonl",
+                    "transform": SequenceTransform(
+                        [
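+                            # record token-usage statistics from the inference results for reporting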
+                            ExtractUsageTransform(model_config),
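+                            # stash the raw completion in raw_output, then add a fresh model_output column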
+                            ColumnRename(
+                                name_mapping={
+                                    "model_output": "raw_output",
+                                }
+                            ),
+                            AddColumn("model_output"),
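+                            # extract the final answer from the raw completion into model_output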
+                            ARCAGI_ExtractAnswer("raw_output", "model_output"),
+                        ]
+                    ),
+                },
+            ),
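+            # ExactMatch writes an ExactMatch_result column that the aggregators below consume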
+            metric_config=MetricConfig(ExactMatch),
+            aggregator_configs=[
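+                # normalized ExactMatch counts (i.e. accuracy) reported separately per data split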
+                AggregatorConfig(
+                    CountAggregator,
+                    {
+                        "column_names": [
+                            "ExactMatch_result",
+                        ],
+                        "filename_base": "OverallMetrics_Separate_Runs_Grouped",
+                        "normalize": True,
+                        "group_by": "split",
+                    },
+                ),
+                # this report aggregates ExactMatch results over the full dataset, normalized to an overall accuracy
+                AggregatorConfig(
+                    CountAggregator,
+                    {
+                        "column_names": [
+                            "ExactMatch_result",
+                        ],
+                        "normalize": True,
+                        "filename_base": "OverallMetrics_Separate_Runs_Total",
+                    },
+                ),
+            ],
+            output_dir=os.path.join(self.log_dir, "eval_report"),
+        )
+
         # Configure the pipeline
         return PipelineConfig(
             [
                 self.data_processing_comp,
                 self.inference_comp,
+                self.evalreporting_comp,
             ],
             self.log_dir,
         )
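
A note on the transform sequence: `ColumnRename` moves the raw completion to `raw_output`, `AddColumn` re-creates `model_output`, and `ARCAGI_ExtractAnswer` fills it with the parsed answer, so `ExactMatch` scores only the extracted answer while the untouched completion is kept for debugging. Below is a minimal sketch of what `ARCAGI_ExtractAnswer` might look like, assuming it follows the same `DFTransformBase` dataclass pattern as the other transforms; the `"Final Answer:"` marker and the fallback behavior are hypothetical, and the actual parsing logic lives in `eureka_ml_insights/data_utils/arc_agi_utils.py`:

```python
from dataclasses import dataclass

import pandas as pd

# Assumed base-class location, matching the pattern of the other transforms.
from eureka_ml_insights.data_utils.transform import DFTransformBase


@dataclass
class ARCAGI_ExtractAnswer(DFTransformBase):
    # The two positional arguments in the config above map onto these fields.
    model_output_column: str
    model_answer_column: str

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        # Write the parsed answer to a new column, leaving the raw completion intact.
        df[self.model_answer_column] = df[self.model_output_column].apply(self.parse_output_answer)
        return df

    @staticmethod
    def parse_output_answer(raw_output):
        # Hypothetical parsing rule: keep whatever follows the last
        # "Final Answer:" marker, falling back to the whole completion.
        if not isinstance(raw_output, str):
            return ""
        marker = "Final Answer:"
        if marker in raw_output:
            return raw_output.rsplit(marker, 1)[-1].strip()
        return raw_output.strip()
```

With that shape, `ExactMatch` can compare the extracted string against the ground truth row by row, which is what the `ExactMatch_result` aggregations above report.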