33import os
44import unittest
55
6- from trinity .common .config import DataProcessorConfig , FormatConfig
6+ from trinity .common .config import DataPipelineConfig , FormatConfig , StorageConfig
77from trinity .common .rewards import AccuracyReward
8- from trinity .common .task import TaskSet
98from trinity .common .workflows import MathWorkflow , SimpleWorkflow
109from trinity .data .core .dataset import RewardSchema , RftDataset
1110from trinity .data .core .formatter import BoxedMathAnswerFormatter , RLHFFormatter
@@ -15,28 +14,34 @@ class TestRftDataset(unittest.TestCase):
1514 """Test cases for RftDataset"""
1615
1716 def setUp (self ) -> None :
18- self .data_config = DataProcessorConfig (
19- source_data_path = os .path .join (
20- os .path .dirname (os .path .realpath (__file__ )),
21- ".." ,
22- ".." ,
23- "test_data" ,
24- "test_10" ,
25- ),
17+ self .data_pipeline_config = DataPipelineConfig (
18+ input_buffers = [StorageConfig (
19+ path = os .path .join (
20+ os .path .dirname (os .path .realpath (__file__ )),
21+ ".." ,
22+ ".." ,
23+ "test_data" ,
24+ "test_10" ,
25+ ),
26+ raw = True ,
27+ )],
2628 format = FormatConfig (
2729 prompt_key = "problem" ,
2830 response_key = "solution" ,
2931 solution_key = "solution" ,
3032 ),
3133 )
32- self .data_config_sample_level_setting = DataProcessorConfig (
33- source_data_path = os .path .join (
34- os .path .dirname (os .path .realpath (__file__ )),
35- ".." ,
36- ".." ,
37- "test_data" ,
38- "test_10_with_rewfn_workflow" ,
39- ),
34+ self .data_pipeline_config_sample_level_setting = DataPipelineConfig (
35+ input_buffers = [StorageConfig (
36+ path = os .path .join (
37+ os .path .dirname (os .path .realpath (__file__ )),
38+ ".." ,
39+ ".." ,
40+ "test_data" ,
41+ "test_10_with_rewfn_workflow" ,
42+ ),
43+ raw = True ,
44+ )],
4045 format = FormatConfig (
4146 prompt_key = "problem" ,
4247 response_key = "solution" ,
@@ -47,13 +52,13 @@ def setUp(self) -> None:
4752 )
4853
4954 def test_rft_dataset_init (self ):
50- dataset = RftDataset (data_config = self .data_config , reward_schema = "default" )
55+ dataset = RftDataset (data_pipeline_config = self .data_pipeline_config , reward_schema = "default" )
5156
5257 self .assertEqual (len (dataset ), 10 )
5358 self .assertIsInstance (dataset .reward_schema , RewardSchema )
5459
5560 def test_format_dataset (self ):
56- dataset = RftDataset (data_config = self .data_config , reward_schema = "default" )
61+ dataset = RftDataset (data_pipeline_config = self .data_pipeline_config , reward_schema = "default" )
5762 original_data = dataset .data
5863 # no formatter
5964 dataset .format (formatters = [])
@@ -62,56 +67,12 @@ def test_format_dataset(self):
6267 # apply formatters
6368 dataset .format (
6469 formatters = [
65- BoxedMathAnswerFormatter (config = self .data_config .format ),
66- RLHFFormatter (config = self .data_config .format ),
70+ BoxedMathAnswerFormatter (config = self .data_pipeline_config .format ),
71+ RLHFFormatter (config = self .data_pipeline_config .format ),
6772 ]
6873 )
6974 self .assertNotEqual (dataset .data , original_data )
7075
71- def test_to_taskset (self ):
72- dataset = RftDataset (data_config = self .data_config , reward_schema = "default" )
73- taskset = dataset .to_taskset ()
74- self .assertIsInstance (taskset , TaskSet )
75- self .assertEqual (len (taskset ), 10 )
76- self .assertIsNone (taskset .reward_fn )
77- self .assertIsNone (taskset .workflow )
78- self .assertEqual (taskset ._index , 0 )
79-
80- def test_to_taskset_with_global_settings (self ):
81- dataset = RftDataset (data_config = self .data_config , reward_schema = "default" )
82- taskset = dataset .to_taskset (
83- reward_fn = AccuracyReward ,
84- workflow = SimpleWorkflow ,
85- )
86- self .assertIsInstance (taskset , TaskSet )
87- self .assertEqual (taskset .workflow , SimpleWorkflow )
88- self .assertEqual (taskset .reward_fn , AccuracyReward )
89-
90- def test_to_taskset_with_sample_level_settings (self ):
91- dataset = RftDataset (
92- data_config = self .data_config_sample_level_setting , reward_schema = "default"
93- )
94- taskset = dataset .to_taskset ()
95- self .assertIsInstance (taskset , TaskSet )
96- for task in taskset .tasks :
97- self .assertEqual (task .workflow , MathWorkflow )
98- self .assertEqual (task .reward_fn , AccuracyReward )
99-
100- def test_to_taskset_with_both_settings (self ):
101- dataset = RftDataset (
102- data_config = self .data_config_sample_level_setting , reward_schema = "default"
103- )
104- taskset = dataset .to_taskset (
105- reward_fn = AccuracyReward ,
106- workflow = SimpleWorkflow ,
107- )
108- self .assertIsInstance (taskset , TaskSet )
109- for task in taskset .tasks :
110- self .assertEqual (task .workflow , MathWorkflow )
111- self .assertEqual (task .reward_fn , AccuracyReward )
112- self .assertEqual (taskset .workflow , SimpleWorkflow )
113- self .assertEqual (taskset .reward_fn , AccuracyReward )
114-
11576
11677if __name__ == "__main__" :
11778 unittest .main ()
0 commit comments