
Commit d66b3de

Refactor on config_manager.py (#23)
1 parent a9c650b commit d66b3de

25 files changed: +1350 −831 lines changed

docs/sphinx_doc/source/tutorial/example_data_functionalities.md

Lines changed: 5 additions & 5 deletions
@@ -42,7 +42,7 @@ data:
   # database related. The result dataset will be stored in the database.
   db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
   # downstream loading related
-  total_epoch: 1
+  total_epochs: 1
   batch_size: 96
   default_workflow_type: 'math_workflow'
 ```
@@ -53,7 +53,7 @@ Here you can set the basic information for the GSM-8K dataset, database informat
 + `dataset_config`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library.
 + `format_config`: some dataset format config items, which are used to map original data field names to unified ones.
 + `db_url`: the URL of the postgresql database to store the result dataset.
-+ `total_epoch`: the total number of epochs to train on this dataset.
++ `total_epochs`: the total number of epochs to train on this dataset.
 + `batch_size`: the training batch size.
 + `default_workflow_type`: the default exploring workflow type. Please refer to [programming guide](trinity_programming_guide.md) for more details.
@@ -74,7 +74,7 @@ data:
   # database related. The result dataset will be stored in the database.
   db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
   # downstream loading related
-  total_epoch: 1
+  total_epochs: 1
   batch_size: 96
   default_workflow_type: 'math_workflow'
@@ -120,7 +120,7 @@ data:
   # database related. The result dataset will be stored in the database.
   db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
   # downstream loading related
-  total_epoch: 1
+  total_epochs: 1
   batch_size: 96
   default_workflow_type: 'math_workflow'
@@ -199,7 +199,7 @@ data:
   # database related. The result dataset will be stored in the database.
   db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
   # downstream loading related
-  total_epoch: 20
+  total_epochs: 20
   batch_size: 32
   default_workflow_type: 'math_workflow'
 ```
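The net effect in this tutorial is a single key rename, `total_epoch` → `total_epochs`; the surrounding values are unchanged. The updated `data` block therefore reads as follows (assembled from the hunks above; the 2-space indentation is assumed, since the diff view flattens it):

```yaml
data:
  # database related. The result dataset will be stored in the database.
  db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
  # downstream loading related
  total_epochs: 1
  batch_size: 96
  default_workflow_type: 'math_workflow'
```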

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 22 additions & 5 deletions
@@ -3,6 +3,18 @@
 The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example.
 
 
+## Monitor
+
+```yaml
+monitor:
+  project: "Trinity-RFT-countdown"
+  name: "qwen2.5-1.5B-countdown"
+```
+
+- `monitor.project`: The project name. It must be set manually.
+- `monitor.name`: The name of the experiment. It must be set manually.
+
+
 ## Monitor
 
 ```yaml
@@ -33,7 +45,7 @@ data:
   max_retry_times: 3
   max_retry_interval: 1
 
-  total_epoch: 20
+  total_epochs: 20
   batch_size: 96
   default_workflow_type: 'math_workflow'
   default_reward_fn_type: 'countdown_reward'
@@ -47,7 +59,7 @@ data:
 - `data.db_url`: The URL of the database.
 - `data.max_retry_times`: The maximum number of retries when loading the dataset from database.
 - `data.max_retry_interval`: The maximum interval between retries when loading the dataset from database.
-- `data.total_epoch`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually.
+- `data.total_epochs`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually.
 - `data.batch_size`: The number of `Task` in one training batch. The real batch size used in training is `data.batch_size` * `actor_rollout_ref.rollout.n` Default is `1`. It should be set manually.
 - `data.default_workflow_type`: The default workflow type used for training.
 - `data.default_reward_fn_type`: The default reward function type used for training.
@@ -345,10 +357,14 @@ algorithm:
   gamma: 1.0
   lam: 1.0
   adv_estimator: gae
+  norm_adv_by_std_in_grpo: True
+  use_kl_in_reward: False
   kl_penalty: kl # how to estimate kl divergence
   kl_ctrl:
     type: fixed
     kl_coef: 0.001
+    horizon: 10000
+    target_kl: 0.1
 
 trainer:
   balance_batch: True
@@ -363,7 +379,7 @@ trainer:
   save_freq: 100
   # auto: find the last ckpt to resume. If can't find, start from scratch
   resume_mode: auto # or auto or resume_path if
-  resume_from_path: False
+  resume_from_path: ""
   test_freq: 100
   critic_warmup: 0
   default_hdfs_dir: null
@@ -383,8 +399,9 @@ trainer:
 - `actor_rollout_ref.actor.grad_clip`: Gradient clip for actor model training.
 - `actor_rollout_ref.actor.clip_ratio`: Used for compute policy loss.
 - `actor_rollout_ref.actor.entropy_coeff`: Used for compute policy loss.
-- `actor_rollout_ref.actor.use_kl_loss`: True for GRPO.
-- `actor_rollout_ref.actor.kl_loss_coef`: Used for GRPO, optional value is `kl`, `abs`, `mse` or `low_var_kl`.
+- `actor_rollout_ref.actor.use_kl_loss`: Whether to enable kl loss.
+- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss.
+- `actor_rollout_ref.actor.kl_loss_type`: How to compute kl loss, optional value is `kl`, `abs`, `mse` or `low_var_kl`.
 - `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size.
 - `actor_rollout_ref.actor.alg_type`: Used for OPMD, optional value is `ppo`, `opmd` or `pairwise_opmd`.
 - `actor_rollout_ref.actor.tau`: strength of regularization w.r.t. old / ref policy.
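Read together, the documentation hunks above describe the following shape for the affected parts of `countdown.yaml` after this commit (a sketch assembled from the diff; only the keys shown in the hunks are included, and nesting/indentation is assumed):

```yaml
monitor:
  project: "Trinity-RFT-countdown"
  name: "qwen2.5-1.5B-countdown"

data:
  total_epochs: 20                # renamed from total_epoch
  batch_size: 96
  default_workflow_type: 'math_workflow'
  default_reward_fn_type: 'countdown_reward'

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: gae
  norm_adv_by_std_in_grpo: True   # newly documented
  use_kl_in_reward: False         # newly documented
  kl_penalty: kl                  # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001
    horizon: 10000                # newly documented
    target_kl: 0.1                # newly documented

trainer:
  resume_mode: auto
  resume_from_path: ""            # previously `False`; now an empty path string
```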

examples/dpo_humanlike/dpo.yaml

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,6 @@
 mode: train
 data:
-  total_epoch: 20
+  total_epochs: 20
   batch_size: 32 # NOTE
   train_split: "train"
   dataset_path: ''
@@ -22,7 +22,6 @@ buffer:
   train_dataset:
     name: dpo_buffer
     storage_type: file
-    algorithm_type: dpo
     path: '/PATH/TO/DATASET/'
     kwargs:
       prompt_type: plaintext # plaintext/messages

examples/dpo_humanlike/train_dpo.yaml

Lines changed: 0 additions & 1 deletion
@@ -173,7 +173,6 @@ trainer:
   save_freq: 30
   # auto: find the last ckpt to resume. If can't find, start from scratch
   resume_mode: auto # or auto or resume_path if
-  resume_from_path: False
   test_freq: 5
   critic_warmup: 0
   default_hdfs_dir: null

examples/grpo_alfworld/alfworld.yaml

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,5 @@
 data:
-  total_epoch: 20
+  total_epochs: 20
   batch_size: 4
   dataset_path: 'scripts/data_prepare/alfworld_data'
   default_workflow_type: 'alfworld_workflow'
@@ -21,7 +21,6 @@ buffer:
   train_dataset:
     name: alfworld_buffer
     storage_type: queue
-    algorithm_type: ppo
     path: 'sqlite:///alfworld.db'
 explorer:
   engine_type: vllm_async

examples/grpo_alfworld/train_alfworld.yaml

Lines changed: 0 additions & 1 deletion
@@ -172,7 +172,6 @@ trainer:
   save_freq: 1
   # auto: find the last ckpt to resume. If can't find, start from scratch
   resume_mode: auto # or auto or resume_path if
-  resume_from_path: False
   test_freq: 100
   critic_warmup: 0
   default_hdfs_dir: null

examples/grpo_gsm8k/gsm8k.yaml

Lines changed: 1 addition & 3 deletions
@@ -18,7 +18,7 @@ data:
   # db related
   db_url: ''
   # downstream loading related
-  total_epoch: 1
+  total_epochs: 1
   batch_size: 96
   default_workflow_type: 'math_workflow'
 model:
@@ -35,12 +35,10 @@ buffer:
   train_dataset:
     name: gsm8k_buffer
     storage_type: queue
-    algorithm_type: ppo
     path: 'sqlite:///gsm8k.db'
   # sft_warmup_dataset: # Uncomment these to enable sft warmup
   # name: warmup_data
   # storage_type: file
-  # algorithm_type: sft
   # path: '/PATH/TO/WARMUP_DATA/'
   # kwargs:
   # prompt_type: plaintext
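The same cleanup repeats across the example configs: the `algorithm_type` key disappears from buffer dataset entries. Based on the gsm8k hunk above, the resulting block would look roughly like this (indentation assumed; the commented-out `sft_warmup_dataset` example loses its `# algorithm_type: sft` line in the same way):

```yaml
buffer:
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'
```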

examples/grpo_gsm8k/train_gsm8k.yaml

Lines changed: 0 additions & 1 deletion
@@ -177,7 +177,6 @@ trainer:
   save_freq: 100
   # auto: find the last ckpt to resume. If can't find, start from scratch
   resume_mode: auto # or auto or resume_path if
-  resume_from_path: False
   test_freq: 5
   critic_warmup: 0
   default_hdfs_dir: null

examples/grpo_math/math.yaml

Lines changed: 2 additions & 3 deletions
@@ -10,7 +10,7 @@ data:
   # db related
   db_url: ''
   # downstream loading related
-  total_epoch: 20
+  total_epochs: 20
   batch_size: 288
   default_workflow_type: 'math_workflow'
 model:
@@ -27,8 +27,7 @@ buffer:
   train_dataset:
     name: math_buffer
     storage_type: queue
-    algorithm_type: ppo
-    path: 'sqlite:////math.db'
+    path: 'sqlite:///math.db'
 explorer:
   engine_type: vllm_async
   engine_num: 2

examples/grpo_math/train_math.yaml

Lines changed: 0 additions & 1 deletion
@@ -169,7 +169,6 @@ trainer:
   save_freq: 100
   # auto: find the last ckpt to resume. If can't find, start from scratch
   resume_mode: auto # or auto or resume_path if
-  resume_from_path: False
   test_freq: 5
   critic_warmup: 0
   default_hdfs_dir: null
