44import unittest
55
66from trinity .common .config import DataPipelineConfig , FormatConfig , StorageConfig
7- from trinity .common .rewards import AccuracyReward
8- from trinity .common .workflows import MathWorkflow , SimpleWorkflow
97from trinity .data .core .dataset import RewardSchema , RftDataset
108from trinity .data .core .formatter import BoxedMathAnswerFormatter , RLHFFormatter
119
@@ -15,33 +13,37 @@ class TestRftDataset(unittest.TestCase):
1513
1614 def setUp (self ) -> None :
1715 self .data_pipeline_config = DataPipelineConfig (
18- input_buffers = [StorageConfig (
19- path = os .path .join (
20- os .path .dirname (os .path .realpath (__file__ )),
21- ".." ,
22- ".." ,
23- "test_data" ,
24- "test_10" ,
25- ),
26- raw = True ,
27- )],
16+ input_buffers = [
17+ StorageConfig (
18+ path = os .path .join (
19+ os .path .dirname (os .path .realpath (__file__ )),
20+ ".." ,
21+ ".." ,
22+ "test_data" ,
23+ "test_10" ,
24+ ),
25+ raw = True ,
26+ )
27+ ],
2828 format = FormatConfig (
2929 prompt_key = "problem" ,
3030 response_key = "solution" ,
3131 solution_key = "solution" ,
3232 ),
3333 )
3434 self .data_pipeline_config_sample_level_setting = DataPipelineConfig (
35- input_buffers = [StorageConfig (
36- path = os .path .join (
37- os .path .dirname (os .path .realpath (__file__ )),
38- ".." ,
39- ".." ,
40- "test_data" ,
41- "test_10_with_rewfn_workflow" ,
42- ),
43- raw = True ,
44- )],
35+ input_buffers = [
36+ StorageConfig (
37+ path = os .path .join (
38+ os .path .dirname (os .path .realpath (__file__ )),
39+ ".." ,
40+ ".." ,
41+ "test_data" ,
42+ "test_10_with_rewfn_workflow" ,
43+ ),
44+ raw = True ,
45+ )
46+ ],
4547 format = FormatConfig (
4648 prompt_key = "problem" ,
4749 response_key = "solution" ,
@@ -52,13 +54,17 @@ def setUp(self) -> None:
5254 )
5355
5456 def test_rft_dataset_init (self ):
55- dataset = RftDataset (data_pipeline_config = self .data_pipeline_config , reward_schema = "default" )
57+ dataset = RftDataset (
58+ data_pipeline_config = self .data_pipeline_config , reward_schema = "default"
59+ )
5660
5761 self .assertEqual (len (dataset ), 10 )
5862 self .assertIsInstance (dataset .reward_schema , RewardSchema )
5963
6064 def test_format_dataset (self ):
61- dataset = RftDataset (data_pipeline_config = self .data_pipeline_config , reward_schema = "default" )
65+ dataset = RftDataset (
66+ data_pipeline_config = self .data_pipeline_config , reward_schema = "default"
67+ )
6268 original_data = dataset .data
6369 # no formatter
6470 dataset .format (formatters = [])
0 commit comments