
Commit 2086bec

Merge TaskSet into Buffer (#34)

1 parent 9a29a96


55 files changed: +942 −878 lines changed

docs/sphinx_doc/source/main.md

Lines changed: 10 additions & 2 deletions
````diff
@@ -186,8 +186,16 @@ You may customize the configurations in [`examples`](https://github.com/modelsco
 model:
   model_path: $MODEL_PATH/{model_name}

-data:
-  dataset_path: $DATASET_PATH/{dataset_name}
+buffer:
+  explorer_input:
+    taskset:
+      name: $TASKSET_NAME
+      path: $DATASET_PATH/{dataset_name}
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+    default_workflow_type: $WORKFLOW_NAME
+    default_reward_fn_type: $REWARD_FN_NAME
 ```

 Please refer to [`examples`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/) for more details.
````
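The `format` block above maps dataset-specific field names onto unified prompt/response fields. A minimal sketch of what that mapping amounts to — the `apply_format` helper and the unified key names are illustrative assumptions, not Trinity-RFT's actual API:

```python
# Sketch of the field mapping expressed by the `format` config.
# `apply_format` and the unified "prompt"/"response" keys are assumptions
# for illustration, not Trinity-RFT's real implementation.

def apply_format(sample: dict, prompt_key: str, response_key: str) -> dict:
    """Map raw dataset fields onto unified prompt/response fields."""
    return {
        "prompt": sample[prompt_key],
        "response": sample[response_key],
    }

raw = {"question": "What is 2 + 2?", "answer": "4"}
unified = apply_format(raw, prompt_key="question", response_key="answer")
print(unified)  # {'prompt': 'What is 2 + 2?', 'response': '4'}
```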

docs/sphinx_doc/source/tutorial/example_async_mode.md

Lines changed: 6 additions & 5 deletions
````diff
@@ -10,18 +10,19 @@ In addition, we need to configure the following parameters in both files.
 The model weights of the explorer and trainer are synchronized once every `sync_iteration_interval * batch_size` tasks.

 ```yaml
-data:
+global_config:
   batch_size: <batch_size>
 # The same checkpoint path
 model:
   checkpoint_path: /PATH/TO/CHECKPOINT

 # The same data_base path
 buffer:
-  train_dataset:
-    name: gsm8k_buffer
-    storage_type: queue
-    path: 'sqlite:///gsm8k.db'
+  trainer_input:
+    experience_buffer:
+      name: gsm8k_buffer
+      storage_type: queue
+      path: 'sqlite:///gsm8k.db'

 synchronizer:
   sync_method: 'checkpoint'
````
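As a concrete check of the synchronization cadence described in this hunk, the number of tasks between two weight syncs is simply the product of the two config values (the numbers below are illustrative, not defaults from the source):

```python
# Illustrative values; the real ones come from your YAML config.
sync_iteration_interval = 10
batch_size = 96

# Weights are synchronized once every this many tasks.
tasks_between_syncs = sync_iteration_interval * batch_size
print(tasks_between_syncs)  # 960
```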

docs/sphinx_doc/source/tutorial/example_data_functionalities.md

Lines changed: 19 additions & 38 deletions
````diff
@@ -31,52 +31,41 @@ Trinity-RFT uses a unified config file to manage all config items. For the data
 In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example:

 ```yaml
-data:
+data_processor:
   # basic info
-  dataset_path: '/path/to/gsm8k'
-  dataset_config:
+  source_data_path: '/path/to/gsm8k'
+  load_kwargs:
     split: 'train' # only need the train split
-  format_config: # set the field mappings
+  format: # set the field mappings
     prompt_key: 'question'
     response_key: 'answer'
   # database related. The result dataset will be stored in the database.
   db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
-  # downstream loading related
-  total_epochs: 1
-  batch_size: 96
-  default_workflow_type: 'math_workflow'
 ```

 Here you can set the basic information for the GSM-8K dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training:

-+ `dataset_path`: the path to the raw dataset.
-+ `dataset_config`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library.
-+ `format_config`: some dataset format config items, which are used to map original data field names to unified ones.
++ `source_data_path`: the path to the raw dataset.
++ `load_kwargs`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library.
++ `format`: some dataset format config items, which are used to map original data field names to unified ones.
 + `db_url`: the URL of the postgresql database to store the result dataset.
-+ `total_epochs`: the total number of epochs to train on this dataset.
-+ `batch_size`: the training batch size.
-+ `default_workflow_type`: the default exploring workflow type. Please refer to [programming guide](trinity_programming_guide.md) for more details.

 In addition, there are several config items related to the data active iterator, which is used to prepare a better dataset. The core part of the data active iterator, Data-Juicer, provides tens of operators to help clean or calculate key information for each sample in the dataset. You can configure this part depending on how familiar you are with Data-Juicer.

 #### Not familiar with Data-Juicer
 If you are not familiar with Data-Juicer, the data module provides a natural-language-based method to config the data processing recipe. What you need to do is only describe the demands of how you want to prepare for the raw dataset, and an agent will be invoked to arrange the data processing recipe for you. Here is an example:

 ```yaml
-data:
+data_processor:
   # basic info
-  dataset_path: '/path/to/gsm8k'
-  dataset_config:
+  source_data_path: '/path/to/gsm8k'
+  load_kwargs:
     split: 'train' # only need the train split
-  format_config: # set the field mappings
+  format: # set the field mappings
     prompt_key: 'question'
     response_key: 'answer'
   # database related. The result dataset will be stored in the database.
   db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
-  # downstream loading related
-  total_epochs: 1
-  batch_size: 96
-  default_workflow_type: 'math_workflow'

   #### new part about data active iterator
   dj_process_desc: 'Please compute difficulty scores for these math questions.'

@@ -109,20 +98,16 @@
 After preparing the Data-Juicer data processing recipe, you can set the `dj_config_path` item in the Trinity-RFT config file to the path to this recipe. For example:

 ```yaml
-data:
+data_processor:
   # basic info
-  dataset_path: '/path/to/gsm8k'
-  dataset_config:
+  source_data_path: '/path/to/gsm8k'
+  load_kwargs:
     split: 'train' # only need the train split
-  format_config: # set the field mappings
+  format: # set the field mappings
     prompt_key: 'question'
     response_key: 'answer'
   # database related. The result dataset will be stored in the database.
   db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
-  # downstream loading related
-  total_epochs: 1
-  batch_size: 96
-  default_workflow_type: 'math_workflow'

   #### new part about data active iterator
   dj_config_path: '/path/to/the/Data-Juicer/data/processing/recipe/above.yaml'

@@ -185,23 +170,19 @@ Trinity-RFT uses a unified config file to manage all config items. For the data
 In this example, assume that you need to rank all math questions and corresponding answers by their difficulties. So you can set these config items like the following example:

 ```yaml
-data:
+data_processor:
   # basic info
-  dataset_path: 'tests/test_data/test_human_annotator'
-  dataset_config:
+  source_data_path: 'tests/test_data/test_human_annotator'
+  load_kwargs:
     split: 'train' # only need the train split
-  format_config: # set the field mappings
+  format: # set the field mappings
     prompt_key: 'prompt'
     chosen_key: 'chosen'
     rejected_key: 'rejected'
   #### new part about data active iterator
   dj_config_path: 'tests/test_configs/human_annotator_test_dj_cfg.yaml'
   # database related. The result dataset will be stored in the database.
   db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
-  # downstream loading related
-  total_epochs: 20
-  batch_size: 32
-  default_workflow_type: 'math_workflow'
 ```

 Here you can set the basic information for the example dataset, database information that is used to store the result dataset, and some other items about downstream dataset loading for exploring and training, which is similar to the example above.
````
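The `db_url` values in these examples are written with `{user_name}`/`{db_name}` placeholders. Assuming they are meant to be filled in with your own PostgreSQL user and database name (a reading based on the URL shape, not documented behavior), the substitution is plain string formatting:

```python
# Fill the placeholders in the db_url template.
# 'alice' and 'trinity' are example values, not defaults from the docs.
db_url_template = 'postgresql://{user_name}@localhost:5432/{db_name}'
db_url = db_url_template.format(user_name='alice', db_name='trinity')
print(db_url)  # postgresql://alice@localhost:5432/trinity
```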

docs/sphinx_doc/source/tutorial/example_dpo.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -51,7 +51,7 @@ buffer:
   train_dataset:
     storage_type: file
     path: <$DATASET_PATH/human_like_dpo_dataset>
-    kwargs:
+    format:
       prompt_type: <prompt_type> # messages/plaintext
       prompt_key: <prompt_key>
       chosen_key: <chosen_key>
```

docs/sphinx_doc/source/tutorial/example_reasoning_basic.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -84,7 +84,7 @@ buffer:
   sft_warmup_dataset:
     storage_type: file
     path: <$DATASET_PATH/{sft_data}>
-    kwargs:
+    format:
       prompt_type: <prompt_type> # messages/plaintext/chatpair
       prompt_key: <prompt_key>
       response_key: <response_key>
```

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 66 additions & 42 deletions
````diff
@@ -2,6 +2,21 @@


 The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example.

+## Global Config
+
+```yaml
+mode: both
+global_config:
+  total_epochs: 1
+  batch_size: 96
+  eval_interval: 1000
+```
+
+- `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only the trainer is launched; `explore` means only the explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
+- `global_config.total_epochs`: The total number of epochs. It should be set manually.
+- `global_config.batch_size`: The batch size used for training. It should be set manually.
+- `global_config.eval_interval`: The interval steps between two evaluations. Default is `1000`.
+

 ## Monitor

@@ -15,45 +30,32 @@ monitor:
 - `monitor.name`: The name of the experiment. It must be set manually.


-## Data
+## Data Processing

 <!-- The `data` configuration specifies the data used for training. It includes the total number of epochs, the batch size, the path to the dataset, the default workflow type, the default reward function type, and the format configuration. -->

 ```yaml
-data:
-  dataset_path: '/PATH/TO/DATASET'
-  train_split: 'train'
-  eval_split: ''
-  dataset_config:
-    split: 'train'
-  format_config:
+data_processor:
+  source_data_path: '/PATH/TO/DATASET'
+  load_kwargs:
+    split: 'train' # only need the train split
+  format:
     prompt_key: 'question'
     response_key: 'answer'

-  db_url: ''
-  max_retry_times: 3
-  max_retry_interval: 1
-
-  total_epochs: 20
-  batch_size: 96
-  default_workflow_type: 'math_workflow'
-  default_reward_fn_type: 'countdown_reward'
+  # cleaner related
+  dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml'
+  clean_strategy: 'iterative'
+  # db related
+  db_url: 'postgresql://{username}@localhost:5432/{db_name}'
 ```

-- `data.dataset_path`: The path to the dataset.
-- `data.train_split`: The split name of the dataset used for training. Default is `train`.
-- `data.eval_split`: The split name of the dataset used for eval.
-- `data.dataset_config`: The configuration for the dataset. <!-- TODO: may only used in Data-Juicer -->
-- `data.format_config`: The configuration for the format of the dataset.
+- `data.source_data_path`: The path to the source dataset.
+- `data.load_kwargs`: The kwargs used in `datasets.load_dataset`.
+- `data.format`: The format of the source dataset. It includes `prompt_key` and `response_key`.
+- `data.dj_config_path`: The path to the Data-Juicer configuration.
+- `data.clean_strategy`: The cleaning strategy used for `DataCleaner`, which iteratively cleans the dataset until targets are met.
 - `data.db_url`: The URL of the database.
-- `data.max_retry_times`: The maximum number of retries when loading the dataset from database.
-- `data.max_retry_interval`: The maximum interval between retries when loading the dataset from database.
-- `data.total_epochs`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually.
-- `data.batch_size`: The number of `Task` in one training batch. The real batch size used in training is `data.batch_size` * `explorer.repeat_times`. It should be set manually.
-- `data.default_workflow_type`: The default workflow type used for training.
-- `data.default_reward_fn_type`: The default reward function type used for training.
-
-<!-- TODO explain the dataset_config and format_config -->

 ## Model

@@ -93,18 +95,40 @@
 buffer:
   max_retry_times: 3
   max_retry_interval: 1
-  train_dataset:
-    name: countdown_buffer
-    storage_type: queue
-    algorithm_type: ppo
-    path: 'sqlite:///countdown.db'
-  sft_warmup_dataset: null
+  explorer_input:
+    taskset:
+      name: countdown
+      path: 'countdown_dataset/oneshot-split'
+      split: train
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+    eval_tasksets: []
+    default_workflow_type: 'math_workflow'
+    default_reward_fn_type: 'countdown_reward'
+  trainer_input:
+    experience_buffer:
+      name: countdown_buffer
+      storage_type: queue
+      path: 'sqlite:///countdown.db'
+    sft_warmup_dataset: null
 ```

-- `buffer.max_retry_times`: The maximum number of retries when loading the dataset from database.
-- `buffer.max_retry_interval`: The maximum interval between retries when loading the dataset from database.
-- `buffer.train_dataset`: The configuration of the training dataset.
-- `buffer.sft_warmup_dataset`: The configuration of the SFT warmup dataset.
+- `buffer.max_retry_times`: The maximum number of retries when loading data from the database.
+- `buffer.max_retry_interval`: The maximum interval between retries when loading data from the database.
+- `buffer.explorer_input.taskset`: The configuration of the taskset.
+- `buffer.explorer_input.taskset.name`: The name of the taskset.
+- `buffer.explorer_input.taskset.path`: The path to the taskset.
+- `buffer.explorer_input.taskset.split`: The split name of the taskset used for training. Default is `train`.
+- `buffer.explorer_input.taskset.format`: The format of the taskset. It includes `prompt_key`, `response_key`, `workflow_key` and `reward_fn_key`.
+- `buffer.explorer_input.eval_tasksets`: The configuration of the eval tasksets, a list of tasksets used for evaluation. It is empty by default.
+- `buffer.explorer_input.default_workflow_type`: The default workflow type for `taskset` and `eval_tasksets`.
+- `buffer.explorer_input.default_reward_fn_type`: The default reward function type for `taskset` and `eval_tasksets`.
+- `buffer.trainer_input.experience_buffer`: The configuration of the experience buffer.
+- `buffer.trainer_input.experience_buffer.name`: The name of the experience buffer.
+- `buffer.trainer_input.experience_buffer.storage_type`: The storage type of the experience buffer. Default is `queue`.
+- `buffer.trainer_input.experience_buffer.path`: The SQL path where the experience buffer is stored. It can be empty to indicate not saving to the database.
+- `buffer.trainer_input.sft_warmup_dataset`: The configuration of the SFT warmup dataset. Its structure is similar to `buffer.explorer_input.taskset`.

 ## Explorer

@@ -157,7 +181,7 @@ synchronizer:
 - `synchronizer.sync_method`: The synchronization method between `trainer` and `explorer`.
 Support `nccl` and `checkpoint`, `nccl` represents that model weights in `explorer` will be synchronized from `trainer` through `nccl`,
 `checkpoint` represents that `explorer` will load the newest checkpoints saved by `trainer` then update its model weights. Default is `nccl`.
-- `synchronizer.sync_interval`: The interval between two synchronizations. Default is `10`. It should be set manually.
+- `synchronizer.sync_interval`: The interval steps between two synchronizations. Default is `10`. It should be set manually.
 - `synchronizer.sync_timeout`: The timeout of the synchronization. Default is `1200`.

 ## Trainer

@@ -176,8 +200,8 @@ trainer:
 - `trainer.algorithm_type`: The type of the algorithm, Support `ppo`, `grpo`, `opmd` and `dpo`.
 - `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually.
 - `trainer.sft_warmup_steps`: The number of steps to warm up the model. Default is `0`.
-- `trainer.eval_interval`: The interval between two evaluations. Default is `1000`.
-- `trainer.save_interval`: The interval between two checkpoints. Default is `100`.
+- `trainer.eval_interval`: The interval steps between two evaluations. Default is `1000`.
+- `trainer.save_interval`: The interval steps between two checkpoints. Default is `100`.

 ### veRL Trainer Configuration
````
docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 9 additions & 4 deletions
````diff
@@ -102,13 +102,18 @@ class ExampleWorkflow(Workflow):

 ### Step 3: Modify Configuration File

-After completing the development of the `Workflow`, you need to modify the configuration file to set the `default_workflow_type` in the `data` domain to the newly registered `Workflow` name.
+After completing the development of the `Workflow`, you need to modify the configuration file to set the `default_workflow_type` in the `buffer.explorer_input` domain to the newly registered `Workflow` name.

 ```yaml
-data:
-  # Other fields
-  default_workflow_type: example_workflow
+buffer:
   # Other fields
+  explorer_input:
+    taskset:
+      name: taskset_name
+      path: 'path/to/taskset'
+      # Other fields
+    eval_tasksets: []
+    default_workflow_type: example_workflow
   # Other fields
 ```

````
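The `default_workflow_type: example_workflow` string has to resolve to the registered `Workflow` class at runtime. A minimal name-to-class registry sketch, purely illustrative (Trinity-RFT's actual registration mechanism is not shown in this diff):

```python
# Hypothetical name -> class registry; Trinity-RFT's real mechanism may differ.
WORKFLOWS: dict[str, type] = {}

def register_workflow(name: str):
    """Decorator that records a workflow class under a config-friendly name."""
    def decorator(cls: type) -> type:
        WORKFLOWS[name] = cls
        return cls
    return decorator

@register_workflow("example_workflow")
class ExampleWorkflow:
    def run(self, task: str) -> str:
        return f"explored: {task}"

# Lookup mirrors what resolving `default_workflow_type` would do.
workflow_cls = WORKFLOWS["example_workflow"]
print(workflow_cls().run("countdown"))  # explored: countdown
```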
