Skip to content

Commit 727404e

Browse files
committed
Add best of n scoring
1 parent 162a2c6 commit 727404e

File tree

1 file changed

+123
-10
lines changed

1 file changed

+123
-10
lines changed

eureka_ml_insights/user_configs/arc_agi.py

Lines changed: 123 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -359,25 +359,45 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
359359
metric_config=MetricConfig(ExactMatch),
360360
aggregator_configs=[
361361
AggregatorConfig(
362-
CountAggregator,
362+
CountAggregator,
363363
{
364364
"column_names": [
365365
"ExactMatch_result",
366366
],
367-
"filename_base": "OverallMetrics_Separate_Runs_Grouped",
367+
"group_by": "data_repeat_id",
368368
"normalize": True,
369-
"group_by": "split",
370-
},
371-
),
369+
"filename_base": "ExactMatch_SeparateRuns",
370+
}),
372371
AggregatorConfig(
373-
CountAggregator,
372+
CountAggregator,
374373
{
375374
"column_names": [
376375
"ExactMatch_result",
377376
],
377+
"group_by":["data_repeat_id", "split"],
378+
"filename_base": "ExactMatch_GroupBy_DatasetSplit_SeparateRuns",
378379
"normalize": True,
379-
"filename_base": "OverallMetrics_Separate_Runs_Total",
380-
}),
380+
},
381+
),
382+
AggregatorConfig(
383+
BiLevelCountAggregator,
384+
{
385+
"column_names": ["ExactMatch_result"],
386+
"first_groupby": "data_repeat_id",
387+
"filename_base": "ExactMatch_AllRuns",
388+
"normalize": True,
389+
},
390+
),
391+
AggregatorConfig(
392+
BiLevelCountAggregator,
393+
{
394+
"column_names": ["ExactMatch_result"],
395+
"first_groupby": ["data_repeat_id", "split"],
396+
"second_groupby": "split",
397+
"filename_base": "ExactMatch_GroupBy_DatasetSplit_AllRuns",
398+
"normalize": True,
399+
},
400+
),
381401
],
382402
output_dir=os.path.join(self.log_dir, "eval_report"),
383403
)
@@ -423,7 +443,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
423443
"ExactMatch_result_numeric",
424444
],
425445
"first_groupby": "uid",
426-
"filename_base": "ExactMatch_Total_BestOfN",
446+
"filename_base": "ExactMatch_BestOfN",
427447
}),
428448
# the first three reports aggregate results by data_point_id and take the best out of N
429449
AggregatorConfig(
@@ -434,14 +454,104 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
434454
],
435455
"first_groupby": "uid",
436456
"second_groupby": "split",
437-
"filename_base": "ExactMatch_Grouped_BestOfN",
457+
"filename_base": "ExactMatch_GroupBy_DatasetSplit_BestOfN",
438458
"agg_fn": "max"
439459
},
440460
),
441461
],
442462
output_dir=os.path.join(self.log_dir, "bestofn_eval_report"),
443463
)
444464

465+
self.worst_of_n_evalreporting_comp = EvalReportingConfig(
466+
component_type=EvalReporting,
467+
data_reader_config=DataSetConfig(
468+
DataReader,
469+
{
470+
"path": os.path.join(self.posteval_data_post_processing_comp.output_dir, "transformed_data.jsonl"),
471+
"format": ".jsonl"
472+
},
473+
),
474+
aggregator_configs=[
475+
AggregatorConfig(
476+
BiLevelAggregator,
477+
{
478+
"column_names": [
479+
"ExactMatch_result_numeric",
480+
],
481+
"first_groupby": "uid",
482+
"filename_base": "ExactMatch_WorstOfN",
483+
}),
484+
# these reports aggregate results by uid; the report below takes the worst (min) out of N
485+
AggregatorConfig(
486+
BiLevelAggregator,
487+
{
488+
"column_names": [
489+
"ExactMatch_result_numeric"
490+
],
491+
"first_groupby": "uid",
492+
"second_groupby": "split",
493+
"filename_base": "ExactMatch_GroupBy_DatasetSplit_WorstOfN",
494+
"agg_fn": "min"
495+
},
496+
),
497+
],
498+
output_dir=os.path.join(self.log_dir, "worstofn_eval_report"),
499+
)
500+
501+
# aggregate the output by majority vote
502+
self.data_post_processing_mv = DataProcessingConfig(
503+
component_type=DataProcessing,
504+
data_reader_config=DataSetConfig(
505+
DataReader,
506+
{
507+
"path": os.path.join(self.evalreporting_comp.output_dir, "metric_results.jsonl"),
508+
"format": ".jsonl",
509+
"transform": SequenceTransform(
510+
[
511+
MajorityVoteTransform(model_output_col="model_output"),
512+
ColumnRename(
513+
name_mapping={
514+
"model_output": "model_output_onerun",
515+
"majority_vote": "model_output",
516+
}
517+
),
518+
RunPythonTransform("df = df[df['data_repeat_id'] == 'repeat_0']"),
519+
]
520+
),
521+
},
522+
),
523+
output_dir=os.path.join(self.log_dir, "data_post_processing_mv"),
524+
)
525+
526+
self.mv_evalreporting_comp = EvalReportingConfig(
527+
component_type=EvalReporting,
528+
data_reader_config=DataSetConfig(
529+
DataReader,
530+
{
531+
"path": os.path.join(self.data_post_processing_mv.output_dir, "transformed_data.jsonl"),
532+
"format": ".jsonl",
533+
},
534+
),
535+
metric_config=MetricConfig(ExactMatch),
536+
aggregator_configs=[
537+
# these two reports aggregate the metrics for the majority vote results
538+
AggregatorConfig(
539+
CountAggregator,
540+
{"column_names": ["ExactMatch_result"], "filename_base": "MajorityVote", "normalize": True},
541+
),
542+
AggregatorConfig(
543+
CountAggregator,
544+
{
545+
"column_names": ["ExactMatch_result"],
546+
"group_by": ["split"],
547+
"filename_base": "MajorityVote_GroupBy_DatasetSplit",
548+
"normalize": True,
549+
},
550+
),
551+
],
552+
output_dir=os.path.join(self.log_dir, "majorityvote_eval_report"),
553+
)
554+
445555
# Configure the pipeline
446556
return PipelineConfig(
447557
[
@@ -451,6 +561,9 @@ def configure_pipeline(self, model_config=None, resume_from=None, resume_logdir=
451561
self.evalreporting_comp,
452562
self.posteval_data_post_processing_comp,
453563
self.best_of_n_evalreporting_comp,
564+
self.worst_of_n_evalreporting_comp,
565+
self.data_post_processing_mv,
566+
self.mv_evalreporting_comp,
454567
],
455568
self.log_dir,
456569
)

0 commit comments

Comments
 (0)