Skip to content

Commit e0874db

Browse files
committed
* reformat existing configs
+ add new raw data reader and writer buffer
* support raw data reader and writer in RftDataset
- remove db exporting in RftDataset and ActiveIterator
1 parent f862e11 commit e0874db

File tree

20 files changed

+284
-382
lines changed

20 files changed

+284
-382
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies = [
3939
"requests",
4040
"tensorboard",
4141
"openai",
42+
"jsonlines",
4243
]
4344

4445
[project.scripts]

tests/data/controllers/task_parser_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from agentscope.models import DashScopeChatWrapper
77
from loguru import logger
88

9-
from trinity.common.config import Config
9+
from trinity.common.config import DataPipelineConfig
1010
from trinity.data.controllers.task_parser import DataTaskParser
1111

1212

@@ -50,18 +50,18 @@ def _run_test(self, rft_config, return_none=False):
5050
logger.info(op_weights)
5151

5252
def test_instruction1(self):
53-
rft_config = Config()
54-
rft_config.data.dj_process_desc = "Please recommend a data filtering strategy for me."
53+
rft_config = DataPipelineConfig()
54+
rft_config.dj_process_desc = "Please recommend a data filtering strategy for me."
5555
self._run_test(rft_config)
5656

5757
def test_instruction2(self):
58-
rft_config = Config()
59-
rft_config.data.dj_process_desc = "Do nothing."
58+
rft_config = DataPipelineConfig()
59+
rft_config.dj_process_desc = "Do nothing."
6060
self._run_test(rft_config, return_none=True)
6161

6262
def test_instruction3(self):
63-
rft_config = Config()
64-
rft_config.data.dj_process_desc = "Remove samples with repeat contents."
63+
rft_config = DataPipelineConfig()
64+
rft_config.dj_process_desc = "Remove samples with repeat contents."
6565
self._run_test(rft_config)
6666

6767

tests/data/core/dataset_test.py

Lines changed: 27 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
import os
44
import unittest
55

6-
from trinity.common.config import DataProcessorConfig, FormatConfig
6+
from trinity.common.config import DataPipelineConfig, FormatConfig, StorageConfig
77
from trinity.common.rewards import AccuracyReward
8-
from trinity.common.task import TaskSet
98
from trinity.common.workflows import MathWorkflow, SimpleWorkflow
109
from trinity.data.core.dataset import RewardSchema, RftDataset
1110
from trinity.data.core.formatter import BoxedMathAnswerFormatter, RLHFFormatter
@@ -15,28 +14,34 @@ class TestRftDataset(unittest.TestCase):
1514
"""Test cases for RftDataset"""
1615

1716
def setUp(self) -> None:
18-
self.data_config = DataProcessorConfig(
19-
source_data_path=os.path.join(
20-
os.path.dirname(os.path.realpath(__file__)),
21-
"..",
22-
"..",
23-
"test_data",
24-
"test_10",
25-
),
17+
self.data_pipeline_config = DataPipelineConfig(
18+
input_buffers=[StorageConfig(
19+
path=os.path.join(
20+
os.path.dirname(os.path.realpath(__file__)),
21+
"..",
22+
"..",
23+
"test_data",
24+
"test_10",
25+
),
26+
raw=True,
27+
)],
2628
format=FormatConfig(
2729
prompt_key="problem",
2830
response_key="solution",
2931
solution_key="solution",
3032
),
3133
)
32-
self.data_config_sample_level_setting = DataProcessorConfig(
33-
source_data_path=os.path.join(
34-
os.path.dirname(os.path.realpath(__file__)),
35-
"..",
36-
"..",
37-
"test_data",
38-
"test_10_with_rewfn_workflow",
39-
),
34+
self.data_pipeline_config_sample_level_setting = DataPipelineConfig(
35+
input_buffers=[StorageConfig(
36+
path=os.path.join(
37+
os.path.dirname(os.path.realpath(__file__)),
38+
"..",
39+
"..",
40+
"test_data",
41+
"test_10_with_rewfn_workflow",
42+
),
43+
raw=True,
44+
)],
4045
format=FormatConfig(
4146
prompt_key="problem",
4247
response_key="solution",
@@ -47,13 +52,13 @@ def setUp(self) -> None:
4752
)
4853

4954
def test_rft_dataset_init(self):
50-
dataset = RftDataset(data_config=self.data_config, reward_schema="default")
55+
dataset = RftDataset(data_pipeline_config=self.data_pipeline_config, reward_schema="default")
5156

5257
self.assertEqual(len(dataset), 10)
5358
self.assertIsInstance(dataset.reward_schema, RewardSchema)
5459

5560
def test_format_dataset(self):
56-
dataset = RftDataset(data_config=self.data_config, reward_schema="default")
61+
dataset = RftDataset(data_pipeline_config=self.data_pipeline_config, reward_schema="default")
5762
original_data = dataset.data
5863
# no formatter
5964
dataset.format(formatters=[])
@@ -62,56 +67,12 @@ def test_format_dataset(self):
6267
# apply formatters
6368
dataset.format(
6469
formatters=[
65-
BoxedMathAnswerFormatter(config=self.data_config.format),
66-
RLHFFormatter(config=self.data_config.format),
70+
BoxedMathAnswerFormatter(config=self.data_pipeline_config.format),
71+
RLHFFormatter(config=self.data_pipeline_config.format),
6772
]
6873
)
6974
self.assertNotEqual(dataset.data, original_data)
7075

71-
def test_to_taskset(self):
72-
dataset = RftDataset(data_config=self.data_config, reward_schema="default")
73-
taskset = dataset.to_taskset()
74-
self.assertIsInstance(taskset, TaskSet)
75-
self.assertEqual(len(taskset), 10)
76-
self.assertIsNone(taskset.reward_fn)
77-
self.assertIsNone(taskset.workflow)
78-
self.assertEqual(taskset._index, 0)
79-
80-
def test_to_taskset_with_global_settings(self):
81-
dataset = RftDataset(data_config=self.data_config, reward_schema="default")
82-
taskset = dataset.to_taskset(
83-
reward_fn=AccuracyReward,
84-
workflow=SimpleWorkflow,
85-
)
86-
self.assertIsInstance(taskset, TaskSet)
87-
self.assertEqual(taskset.workflow, SimpleWorkflow)
88-
self.assertEqual(taskset.reward_fn, AccuracyReward)
89-
90-
def test_to_taskset_with_sample_level_settings(self):
91-
dataset = RftDataset(
92-
data_config=self.data_config_sample_level_setting, reward_schema="default"
93-
)
94-
taskset = dataset.to_taskset()
95-
self.assertIsInstance(taskset, TaskSet)
96-
for task in taskset.tasks:
97-
self.assertEqual(task.workflow, MathWorkflow)
98-
self.assertEqual(task.reward_fn, AccuracyReward)
99-
100-
def test_to_taskset_with_both_settings(self):
101-
dataset = RftDataset(
102-
data_config=self.data_config_sample_level_setting, reward_schema="default"
103-
)
104-
taskset = dataset.to_taskset(
105-
reward_fn=AccuracyReward,
106-
workflow=SimpleWorkflow,
107-
)
108-
self.assertIsInstance(taskset, TaskSet)
109-
for task in taskset.tasks:
110-
self.assertEqual(task.workflow, MathWorkflow)
111-
self.assertEqual(task.reward_fn, AccuracyReward)
112-
self.assertEqual(taskset.workflow, SimpleWorkflow)
113-
self.assertEqual(taskset.reward_fn, AccuracyReward)
114-
11576

11677
if __name__ == "__main__":
11778
unittest.main()

tests/data/core/formatter_test.py

Lines changed: 59 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import unittest
55

6-
from trinity.common.config import DataProcessorConfig, FormatConfig
6+
from trinity.common.config import DataPipelineConfig, FormatConfig, StorageConfig
77
from trinity.data.core.dataset import RftDataset
88
from trinity.data.core.formatter import (
99
BoxedMathAnswerFormatter,
@@ -18,14 +18,17 @@ class TestBoxedMathDataset(unittest.TestCase):
1818
"""Test cases for RftDataset"""
1919

2020
def setUp(self) -> None:
21-
self.data_config = DataProcessorConfig(
22-
source_data_path=os.path.join(
23-
os.path.dirname(os.path.realpath(__file__)),
24-
"..",
25-
"..",
26-
"test_data",
27-
"test_10",
28-
),
21+
self.data_config = DataPipelineConfig(
22+
input_buffers=[StorageConfig(
23+
path=os.path.join(
24+
os.path.dirname(os.path.realpath(__file__)),
25+
"..",
26+
"..",
27+
"test_data",
28+
"test_10",
29+
),
30+
raw=True,
31+
)],
2932
format=FormatConfig(
3033
prompt_key="problem",
3134
response_key="answer",
@@ -43,12 +46,12 @@ def test_init(self):
4346
self.assertEqual(formatter.config.chat_template, "User: {}\nAssistant: ")
4447
# test for default configs
4548
self.assertEqual(formatter.config.reward_key, "")
46-
self.assertEqual(formatter.config.chosen_key, "")
47-
self.assertEqual(formatter.config.rejected_key, "")
49+
self.assertEqual(formatter.config.chosen_key, "chosen")
50+
self.assertEqual(formatter.config.rejected_key, "rejected")
4851
self.assertEqual(formatter.config.label_key, "")
4952

5053
def test_transform(self):
51-
dataset = RftDataset(data_config=self.data_config, reward_schema="default")
54+
dataset = RftDataset(data_pipeline_config=self.data_config, reward_schema="default")
5255
formatter = BoxedMathAnswerFormatter(config=self.data_config.format)
5356
self.assertNotIn(formatter.config.response_key, dataset.data.column_names)
5457
dataset.format(formatter)
@@ -59,14 +62,17 @@ class TestRLHFFormatter(unittest.TestCase):
5962
"""Test cases for RLHFFormatter"""
6063

6164
def setUp(self) -> None:
62-
self.data_config = DataProcessorConfig(
63-
source_data_path=os.path.join(
64-
os.path.dirname(os.path.realpath(__file__)),
65-
"..",
66-
"..",
67-
"test_data",
68-
"test_10",
69-
),
65+
self.data_config = DataPipelineConfig(
66+
input_buffers=[StorageConfig(
67+
path=os.path.join(
68+
os.path.dirname(os.path.realpath(__file__)),
69+
"..",
70+
"..",
71+
"test_data",
72+
"test_10",
73+
),
74+
raw=True,
75+
)],
7076
format=FormatConfig(
7177
prompt_key="problem",
7278
chat_template="User: {}\nAssistant: ",
@@ -107,14 +113,17 @@ class TestRewardFormatter(unittest.TestCase):
107113
"""Test cases for RewardFormatter"""
108114

109115
def setUp(self) -> None:
110-
self.data_config = DataProcessorConfig(
111-
source_data_path=os.path.join(
112-
os.path.dirname(os.path.realpath(__file__)),
113-
"..",
114-
"..",
115-
"test_data",
116-
"test_10",
117-
),
116+
self.data_config = DataPipelineConfig(
117+
input_buffers=[StorageConfig(
118+
path=os.path.join(
119+
os.path.dirname(os.path.realpath(__file__)),
120+
"..",
121+
"..",
122+
"test_data",
123+
"test_10",
124+
),
125+
raw=True,
126+
)],
118127
format=FormatConfig(
119128
prompt_key="problem",
120129
chosen_key="chosen",
@@ -164,14 +173,17 @@ class TestSFTFormatter(unittest.TestCase):
164173
"""Test cases for SFTFormatter"""
165174

166175
def setUp(self) -> None:
167-
self.data_config = DataProcessorConfig(
168-
source_data_path=os.path.join(
169-
os.path.dirname(os.path.realpath(__file__)),
170-
"..",
171-
"..",
172-
"test_data",
173-
"test_10",
174-
),
176+
self.data_config = DataPipelineConfig(
177+
input_buffers=[StorageConfig(
178+
path=os.path.join(
179+
os.path.dirname(os.path.realpath(__file__)),
180+
"..",
181+
"..",
182+
"test_data",
183+
"test_10",
184+
),
185+
raw=True,
186+
)],
175187
format=FormatConfig(
176188
prompt_key="problem",
177189
response_key="answer",
@@ -217,14 +229,17 @@ class TestComposedFormatter(unittest.TestCase):
217229
"""Test cases for ComposedFormatter"""
218230

219231
def setUp(self) -> None:
220-
self.data_config = DataProcessorConfig(
221-
source_data_path=os.path.join(
222-
os.path.dirname(os.path.realpath(__file__)),
223-
"..",
224-
"..",
225-
"test_data",
226-
"test_10",
227-
),
232+
self.data_config = DataPipelineConfig(
233+
input_buffers=[StorageConfig(
234+
path=os.path.join(
235+
os.path.dirname(os.path.realpath(__file__)),
236+
"..",
237+
"..",
238+
"test_data",
239+
"test_10",
240+
),
241+
raw=True,
242+
)],
228243
format=FormatConfig(
229244
prompt_key="problem",
230245
response_key="answer",

0 commit comments

Comments
 (0)