|
24 | 24 | from eureka_ml_insights.user_configs import ( |
25 | 25 | AIME_PIPELINE, |
26 | 26 | AIME_SEQ_PIPELINE, |
| 27 | + ARC_AGI_v1_PIPELINE, |
| 28 | + ARC_AGI_v1_PIPELINE_5Run, |
| 29 | + COT_ARC_AGI_v1_PIPELINE, |
| 30 | + COT_ARC_AGI_v1_PIPELINE_5Run, |
27 | 31 | DNA_PIPELINE, |
28 | 32 | GEOMETER_PIPELINE, |
29 | 33 | GSM8K_PIPELINE, |
@@ -400,6 +404,38 @@ def configure_pipeline(self): |
400 | 404 | return config |
401 | 405 |
|
402 | 406 |
|
class TEST_ARC_AGI_v1_PIPELINE(ARC_AGI_v1_PIPELINE):
    """Test config for the ARC-AGI v1 benchmark with a test model and test data reader.

    Original comment said "BA Calendar benchmark" — a copy-paste error; this class
    derives from the ARC-AGI v1 pipeline.
    """

    def configure_pipeline(self):
        # Substitute the lightweight test model so no real inference backend is needed.
        config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {}))
        # Swap in the test HF reader so the real dataset is not downloaded during tests.
        self.data_processing_comp.data_reader_config.class_name = TestHFDataReader
        return config
| 413 | + |
| 414 | + |
class TEST_ARC_AGI_v1_PIPELINE_5Run(ARC_AGI_v1_PIPELINE_5Run):
    """Test config for the ARC-AGI v1 5-run benchmark with a test model and test data reader.

    Original comment said "BA Calendar benchmark" — a copy-paste error; this class
    derives from the ARC-AGI v1 5-run pipeline.
    """

    def configure_pipeline(self):
        # Substitute the lightweight test model so no real inference backend is needed.
        config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {}))
        # Swap in the test HF reader so the real dataset is not downloaded during tests.
        self.data_processing_comp.data_reader_config.class_name = TestHFDataReader
        return config
| 421 | + |
| 422 | + |
class TEST_COT_ARC_AGI_v1_PIPELINE(COT_ARC_AGI_v1_PIPELINE):
    """Test config for the chain-of-thought ARC-AGI v1 benchmark with a test model and test data reader.

    Original comment said "BA Calendar benchmark" — a copy-paste error; this class
    derives from the CoT ARC-AGI v1 pipeline.
    """

    def configure_pipeline(self):
        # Substitute the lightweight test model so no real inference backend is needed.
        config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {}))
        # Swap in the test HF reader so the real dataset is not downloaded during tests.
        self.data_processing_comp.data_reader_config.class_name = TestHFDataReader
        return config
| 429 | + |
| 430 | + |
class TEST_COT_ARC_AGI_v1_PIPELINE_5Run(COT_ARC_AGI_v1_PIPELINE_5Run):
    """Test config for the chain-of-thought ARC-AGI v1 5-run benchmark with a test model and test data reader.

    Original comment said "BA Calendar benchmark" — a copy-paste error; this class
    derives from the CoT ARC-AGI v1 5-run pipeline.
    """

    def configure_pipeline(self):
        # Substitute the lightweight test model so no real inference backend is needed.
        config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {}))
        # Swap in the test HF reader so the real dataset is not downloaded during tests.
        self.data_processing_comp.data_reader_config.class_name = TestHFDataReader
        return config
| 437 | + |
| 438 | + |
403 | 439 | class PipelineTest: |
404 | 440 | def setUp(self) -> None: |
405 | 441 | self.conf = self.get_config() |
@@ -661,5 +697,109 @@ def get_config(self): |
661 | 697 | return TEST_GSMSYMBOLIC_PIPELINE().pipeline_config |
662 | 698 |
|
663 | 699 |
|
class ARC_AGI_v1_PipelineTest(PipelineTest, unittest.TestCase):
    """Runs the shared pipeline checks against the ARC-AGI v1 test pipeline."""

    def get_config(self):
        self.test_pipeline = TEST_ARC_AGI_v1_PIPELINE()
        self.config = self.test_pipeline.pipeline_config
        return self.config

    def setUp(self) -> None:
        super().setUp()
        # Both the per-run and the best-of-n eval reports emit aggregator files.
        self.eval_configs = [
            self.test_pipeline.evalreporting_comp,
            self.test_pipeline.best_of_n_evalreporting_comp,
        ]

    def test_outputs_exist(self) -> None:
        logging.info("Running test_outputs_exist test in PipelineTest")

        def produced(name):
            # True when at least one output path contains `name`.
            return any(name in str(f) for f in self.files)

        self.assertTrue(produced("transformed_data.jsonl"))
        if self.data_reader_config.prompt_template_path:
            self.assertTrue(produced("processed_prompts.jsonl"))
        self.assertTrue(produced("inference_result.jsonl"))
        if self.eval_config.metric_config is not None:
            self.assertTrue(produced("metric_results.jsonl"))
        # Every configured aggregator (across both eval components) must have
        # written exactly one file.
        n_aggregators = sum(len(cfg.aggregator_configs) for cfg in self.eval_configs)
        n_aggregator_files = sum(1 for f in self.files if "aggregator" in str(f))
        self.assertEqual(n_aggregators, n_aggregator_files)
| 725 | + |
class ARC_AGI_v1_Pipeline_5RunTest(PipelineTest, unittest.TestCase):
    """Runs the shared pipeline checks against the ARC-AGI v1 5-run test pipeline."""

    def get_config(self):
        self.test_pipeline = TEST_ARC_AGI_v1_PIPELINE_5Run()
        self.config = self.test_pipeline.pipeline_config
        return self.config

    def setUp(self) -> None:
        super().setUp()
        # Both the per-run and the best-of-n eval reports emit aggregator files.
        self.eval_configs = [
            self.test_pipeline.evalreporting_comp,
            self.test_pipeline.best_of_n_evalreporting_comp,
        ]

    def test_outputs_exist(self) -> None:
        logging.info("Running test_outputs_exist test in PipelineTest")

        def produced(name):
            # True when at least one output path contains `name`.
            return any(name in str(f) for f in self.files)

        self.assertTrue(produced("transformed_data.jsonl"))
        if self.data_reader_config.prompt_template_path:
            self.assertTrue(produced("processed_prompts.jsonl"))
        self.assertTrue(produced("inference_result.jsonl"))
        if self.eval_config.metric_config is not None:
            self.assertTrue(produced("metric_results.jsonl"))
        # Every configured aggregator (across both eval components) must have
        # written exactly one file.
        n_aggregators = sum(len(cfg.aggregator_configs) for cfg in self.eval_configs)
        n_aggregator_files = sum(1 for f in self.files if "aggregator" in str(f))
        self.assertEqual(n_aggregators, n_aggregator_files)
| 750 | + |
| 751 | + |
class COT_ARC_AGI_v1_PIPELINETest(PipelineTest, unittest.TestCase):
    """Runs the shared pipeline checks against the CoT ARC-AGI v1 test pipeline."""

    def get_config(self):
        self.test_pipeline = TEST_COT_ARC_AGI_v1_PIPELINE()
        self.config = self.test_pipeline.pipeline_config
        return self.config

    def setUp(self) -> None:
        super().setUp()
        # Both the per-run and the best-of-n eval reports emit aggregator files.
        self.eval_configs = [
            self.test_pipeline.evalreporting_comp,
            self.test_pipeline.best_of_n_evalreporting_comp,
        ]

    def test_outputs_exist(self) -> None:
        logging.info("Running test_outputs_exist test in PipelineTest")

        def produced(name):
            # True when at least one output path contains `name`.
            return any(name in str(f) for f in self.files)

        self.assertTrue(produced("transformed_data.jsonl"))
        if self.data_reader_config.prompt_template_path:
            self.assertTrue(produced("processed_prompts.jsonl"))
        self.assertTrue(produced("inference_result.jsonl"))
        if self.eval_config.metric_config is not None:
            self.assertTrue(produced("metric_results.jsonl"))
        # Every configured aggregator (across both eval components) must have
        # written exactly one file.
        n_aggregators = sum(len(cfg.aggregator_configs) for cfg in self.eval_configs)
        n_aggregator_files = sum(1 for f in self.files if "aggregator" in str(f))
        self.assertEqual(n_aggregators, n_aggregator_files)
| 776 | + |
| 777 | + |
class COT_ARC_AGI_v1_PIPELINE_5RunTest(PipelineTest, unittest.TestCase):
    """Runs the shared pipeline checks against the CoT ARC-AGI v1 5-run test pipeline."""

    def get_config(self):
        self.test_pipeline = TEST_COT_ARC_AGI_v1_PIPELINE_5Run()
        self.config = self.test_pipeline.pipeline_config
        return self.config

    def setUp(self) -> None:
        super().setUp()
        # Both the per-run and the best-of-n eval reports emit aggregator files.
        self.eval_configs = [
            self.test_pipeline.evalreporting_comp,
            self.test_pipeline.best_of_n_evalreporting_comp,
        ]

    def test_outputs_exist(self) -> None:
        logging.info("Running test_outputs_exist test in PipelineTest")

        def produced(name):
            # True when at least one output path contains `name`.
            return any(name in str(f) for f in self.files)

        self.assertTrue(produced("transformed_data.jsonl"))
        if self.data_reader_config.prompt_template_path:
            self.assertTrue(produced("processed_prompts.jsonl"))
        self.assertTrue(produced("inference_result.jsonl"))
        if self.eval_config.metric_config is not None:
            self.assertTrue(produced("metric_results.jsonl"))
        # Every configured aggregator (across both eval components) must have
        # written exactly one file.
        n_aggregators = sum(len(cfg.aggregator_configs) for cfg in self.eval_configs)
        n_aggregator_files = sum(1 for f in self.files if "aggregator" in str(f))
        self.assertEqual(n_aggregators, n_aggregator_files)
| 802 | + |
| 803 | + |
# Allow running this test module directly (e.g. `python this_file.py`).
if __name__ == "__main__":
    unittest.main()
0 commit comments