Commit 4198c42
Add unit tests
1 parent 97301d3 commit 4198c42

2 files changed: +142 -2 lines changed
eureka_ml_insights/user_configs/__init__.py
Lines changed: 2 additions & 2 deletions

@@ -8,8 +8,8 @@
 from .arc_agi import (
     ARC_AGI_v1_PIPELINE,
     ARC_AGI_v1_PIPELINE_5Run,
-    Phi_ARC_AGI_v1_PIPELINE,
-    Phi_ARC_AGI_v1_PIPELINE_5Run,
+    COT_ARC_AGI_v1_PIPELINE,
+    COT_ARC_AGI_v1_PIPELINE_5Run,
 )
 from .ba_calendar import (
     BA_Calendar_Parallel_PIPELINE,
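
The only substantive change here is a rename: the chain-of-thought ARC-AGI configs are now exported under the COT_ prefix instead of Phi_. A minimal sketch of an import site after the rename; the instantiation line is illustrative and assumes the config can be built with its default arguments (otherwise a model_config has to be supplied through configure_pipeline, as the tests below do):

from eureka_ml_insights.user_configs import (
    ARC_AGI_v1_PIPELINE,
    COT_ARC_AGI_v1_PIPELINE,  # formerly Phi_ARC_AGI_v1_PIPELINE
)

# Instantiating a user config builds the pipeline definition, which is then
# available on .pipeline_config (the same access pattern used in pipeline_tests.py).
pipeline = COT_ARC_AGI_v1_PIPELINE().pipeline_config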

tests/pipeline_tests.py
Lines changed: 140 additions & 0 deletions

@@ -24,6 +24,10 @@
 from eureka_ml_insights.user_configs import (
     AIME_PIPELINE,
     AIME_SEQ_PIPELINE,
+    ARC_AGI_v1_PIPELINE,
+    ARC_AGI_v1_PIPELINE_5Run,
+    COT_ARC_AGI_v1_PIPELINE,
+    COT_ARC_AGI_v1_PIPELINE_5Run,
     DNA_PIPELINE,
     GEOMETER_PIPELINE,
     GSM8K_PIPELINE,
@@ -400,6 +404,38 @@ def configure_pipeline(self):
         return config


+class TEST_ARC_AGI_v1_PIPELINE(ARC_AGI_v1_PIPELINE):
+    # Test config for the ARC AGI v1 benchmark with TestModel and TestDataLoader
+    def configure_pipeline(self):
+        config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {}))
+        self.data_processing_comp.data_reader_config.class_name = TestHFDataReader
+        return config
+
+
+class TEST_ARC_AGI_v1_PIPELINE_5Run(ARC_AGI_v1_PIPELINE_5Run):
+    # Test config for the 5-run ARC AGI v1 benchmark with TestModel and TestDataLoader
+    def configure_pipeline(self):
+        config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {}))
+        self.data_processing_comp.data_reader_config.class_name = TestHFDataReader
+        return config
+
+
+class TEST_COT_ARC_AGI_v1_PIPELINE(COT_ARC_AGI_v1_PIPELINE):
+    # Test config for the chain-of-thought ARC AGI v1 benchmark with TestModel and TestDataLoader
+    def configure_pipeline(self):
+        config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {}))
+        self.data_processing_comp.data_reader_config.class_name = TestHFDataReader
+        return config
+
+
+class TEST_COT_ARC_AGI_v1_PIPELINE_5Run(COT_ARC_AGI_v1_PIPELINE_5Run):
+    # Test config for the 5-run chain-of-thought ARC AGI v1 benchmark with TestModel and TestDataLoader
+    def configure_pipeline(self):
+        config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {}))
+        self.data_processing_comp.data_reader_config.class_name = TestHFDataReader
+        return config
+
+
 class PipelineTest:
     def setUp(self) -> None:
         self.conf = self.get_config()
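
The four TEST_* configs above differ only in the production class they subclass: each routes a GenericTestModel through configure_pipeline and swaps the data reader class for TestHFDataReader so the test needs no real model or dataset download. As a hypothetical refactoring (not part of this commit, and assuming GenericTestModel, TestHFDataReader, and ModelConfig are imported as they already are in pipeline_tests.py), the shared shape could be written once as a mixin:

class _ARCTestOverridesMixin:
    """Hypothetical mixin capturing the pattern shared by the four TEST_* configs."""

    def configure_pipeline(self):
        # Swap the real model for the lightweight GenericTestModel.
        config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {}))
        # Point the data processing component at the test HF data reader.
        self.data_processing_comp.data_reader_config.class_name = TestHFDataReader
        return config


class TEST_ARC_AGI_v1_PIPELINE(_ARCTestOverridesMixin, ARC_AGI_v1_PIPELINE):
    pass


class TEST_COT_ARC_AGI_v1_PIPELINE_5Run(_ARCTestOverridesMixin, COT_ARC_AGI_v1_PIPELINE_5Run):
    pass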
@@ -661,5 +697,109 @@ def get_config(self):
         return TEST_GSMSYMBOLIC_PIPELINE().pipeline_config


+class ARC_AGI_v1_PipelineTest(PipelineTest, unittest.TestCase):
+    def get_config(self):
+        self.test_pipeline = TEST_ARC_AGI_v1_PIPELINE()
+        self.config = self.test_pipeline.pipeline_config
+        return self.config
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.eval_configs = [
+            self.test_pipeline.evalreporting_comp,
+            self.test_pipeline.best_of_n_evalreporting_comp
+        ]
+
+    def test_outputs_exist(self) -> None:
+        logging.info("Running test_outputs_exist test in PipelineTest")
+        self.assertTrue(any("transformed_data.jsonl" in str(file) for file in self.files))
+        if self.data_reader_config.prompt_template_path:
+            self.assertTrue(any("processed_prompts.jsonl" in str(file) for file in self.files))
+        self.assertTrue(any("inference_result.jsonl" in str(file) for file in self.files))
+        if self.eval_config.metric_config is not None:
+            self.assertTrue(any("metric_results.jsonl" in str(file) for file in self.files))
+        n_aggregators = len([config for eval_config in self.eval_configs for config in eval_config.aggregator_configs])
+        n_aggregator_files = len([file for file in self.files if "aggregator" in str(file)])
+        self.assertEqual(n_aggregators, n_aggregator_files)
+
+
+class ARC_AGI_v1_Pipeline_5RunTest(PipelineTest, unittest.TestCase):
+    def get_config(self):
+        self.test_pipeline = TEST_ARC_AGI_v1_PIPELINE_5Run()
+        self.config = self.test_pipeline.pipeline_config
+        return self.config
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.eval_configs = [
+            self.test_pipeline.evalreporting_comp,
+            self.test_pipeline.best_of_n_evalreporting_comp
+        ]
+
+    def test_outputs_exist(self) -> None:
+        logging.info("Running test_outputs_exist test in PipelineTest")
+        self.assertTrue(any("transformed_data.jsonl" in str(file) for file in self.files))
+        if self.data_reader_config.prompt_template_path:
+            self.assertTrue(any("processed_prompts.jsonl" in str(file) for file in self.files))
+        self.assertTrue(any("inference_result.jsonl" in str(file) for file in self.files))
+        if self.eval_config.metric_config is not None:
+            self.assertTrue(any("metric_results.jsonl" in str(file) for file in self.files))
+        n_aggregators = len([config for eval_config in self.eval_configs for config in eval_config.aggregator_configs])
+        n_aggregator_files = len([file for file in self.files if "aggregator" in str(file)])
+        self.assertEqual(n_aggregators, n_aggregator_files)
+
+
+class COT_ARC_AGI_v1_PIPELINETest(PipelineTest, unittest.TestCase):
+    def get_config(self):
+        self.test_pipeline = TEST_COT_ARC_AGI_v1_PIPELINE()
+        self.config = self.test_pipeline.pipeline_config
+        return self.config
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.eval_configs = [
+            self.test_pipeline.evalreporting_comp,
+            self.test_pipeline.best_of_n_evalreporting_comp
+        ]
+
+    def test_outputs_exist(self) -> None:
+        logging.info("Running test_outputs_exist test in PipelineTest")
+        self.assertTrue(any("transformed_data.jsonl" in str(file) for file in self.files))
+        if self.data_reader_config.prompt_template_path:
+            self.assertTrue(any("processed_prompts.jsonl" in str(file) for file in self.files))
+        self.assertTrue(any("inference_result.jsonl" in str(file) for file in self.files))
+        if self.eval_config.metric_config is not None:
+            self.assertTrue(any("metric_results.jsonl" in str(file) for file in self.files))
+        n_aggregators = len([config for eval_config in self.eval_configs for config in eval_config.aggregator_configs])
+        n_aggregator_files = len([file for file in self.files if "aggregator" in str(file)])
+        self.assertEqual(n_aggregators, n_aggregator_files)
+
+
+class COT_ARC_AGI_v1_PIPELINE_5RunTest(PipelineTest, unittest.TestCase):
+    def get_config(self):
+        self.test_pipeline = TEST_COT_ARC_AGI_v1_PIPELINE_5Run()
+        self.config = self.test_pipeline.pipeline_config
+        return self.config
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.eval_configs = [
+            self.test_pipeline.evalreporting_comp,
+            self.test_pipeline.best_of_n_evalreporting_comp
+        ]
+
+    def test_outputs_exist(self) -> None:
+        logging.info("Running test_outputs_exist test in PipelineTest")
+        self.assertTrue(any("transformed_data.jsonl" in str(file) for file in self.files))
+        if self.data_reader_config.prompt_template_path:
+            self.assertTrue(any("processed_prompts.jsonl" in str(file) for file in self.files))
+        self.assertTrue(any("inference_result.jsonl" in str(file) for file in self.files))
+        if self.eval_config.metric_config is not None:
+            self.assertTrue(any("metric_results.jsonl" in str(file) for file in self.files))
+        n_aggregators = len([config for eval_config in self.eval_configs for config in eval_config.aggregator_configs])
+        n_aggregator_files = len([file for file in self.files if "aggregator" in str(file)])
+        self.assertEqual(n_aggregators, n_aggregator_files)
+
+
 if __name__ == "__main__":
     unittest.main()
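
The new test classes inherit from PipelineTest and unittest.TestCase, so they run with the rest of the suite through unittest.main(). To run just one of the added ARC-AGI cases, something like the following works (the tests.pipeline_tests module path is assumed from the file location in this diff; any of the four new class names can be substituted):

import unittest

# Load only the ARC-AGI v1 pipeline test case and run it with a verbose text runner.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.pipeline_tests.ARC_AGI_v1_PipelineTest"
)
unittest.TextTestRunner(verbosity=2).run(suite)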
