
Commit ff919db

UPD on config_manager.py and docs (#4)

1 parent 638d335 commit ff919db

12 files changed: +852 −508 lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 14 additions & 15 deletions

````diff
@@ -2,6 +2,18 @@
 
 The following is the main config file for Trinity-RFT. Take `scripts/config/countdown.yaml` as an example.
 
+
+## Monitor
+
+```yaml
+monitor:
+  project: "Trinity-RFT-countdown"
+  name: "qwen2.5-1.5B-countdown"
+```
+
+- `monitor.project`: The project name. It must be set manually.
+- `monitor.name`: The name of the experiment. It must be set manually.
+
 ## Data
 
 <!-- The `data` configuration specifies the data used for training. It includes the total number of epochs, the batch size, the path to the dataset, the default workflow type, the default reward function type, and the format configuration. -->
@@ -53,15 +65,13 @@ model:
   max_prompt_tokens: 256
   max_response_tokens: 1024
   checkpoint_path: 'checkpoints/qwen2.5-1.5B-countdown'
-  load_checkpoint: true
 ```
 
 - `model.model_path`: The path to the model checkpoint. It must be set manually.
 - `model.critic_model_path`: The path to the critic model checkpoint. If not set, the `model.critic_model_path` will be set to `model.model_path`.
 - `model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
 - `model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
 - `model.checkpoint_path`: The path to the checkpoint of the model. It must be set manually.
-- `model.load_checkpoint`: Whether to load the checkpoint of the model. Default is `true`.
 
 ## Cluster
 
@@ -149,19 +159,6 @@ synchronizer:
 - `synchronizer.sync_method`: The synchronization method, Support `online` and `offline`. Default is `online`.
 - `synchronizer.sync_iteration_interval`: The interval between two synchronizations. Default is `10`. It should be set manually.
 
-## Monitor
-
-```yaml
-monitor:
-  cache_root_dir: ""
-  project: "Trinity-RFT-countdown"
-  name: "qwen2.5-1.5B-countdown"
-```
-
-- `monitor.cache_root_dir`: The root directory of the cache. Default is `os.path.join(model.checkpoint_path, ".cache")`.
-- `monitor.project`: The project name. It must be set manually.
-- `monitor.name`: The name of the experiment. It must be set manually.
-
 ## Trainer
 
 ```yaml
@@ -386,6 +383,7 @@ trainer:
 - `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`: Batch size for one GPU in one forward pass.
 - `actor_rollout_ref.actor.grad_clip`: Gradient clip for actor model training.
 - `actor_rollout_ref.actor.clip_ratio`: Used for compute policy loss.
+- `actor_rollout_ref.actor.entropy_coeff`: Used for compute policy loss.
 - `actor_rollout_ref.actor.use_kl_loss`: True for GRPO.
 - `actor_rollout_ref.actor.kl_loss_coef`: Used for GRPO, optional value is `kl`, `abs`, `mse` or `low_var_kl`.
 - `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size.
@@ -412,6 +410,7 @@ trainer:
 
 - `algorithm`: Training algorithm settings.
 
+- `trainer.balance_batch`: Whether to balance batch size between GPUs during training.
 - `trainer.save_freq`: Frequency of saving checkpoints.
 - `trainer.resume_mode`: Resume mode for training. Support `disable`, `auto` and `resume_path`.
 - `trainer.resume_from_path`: Path to resume from.
````
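Putting the relocated `monitor` section next to the `model` section as the docs now describe them, a minimal config fragment after this commit might look as follows. This is a sketch, not a complete config: the `model_path` value is a placeholder, the removed `load_checkpoint` key is omitted, and the other sections (`data`, `cluster`, `trainer`, `synchronizer`) are unchanged by this commit.

```yaml
monitor:
  project: "Trinity-RFT-countdown"    # must be set manually
  name: "qwen2.5-1.5B-countdown"      # must be set manually

model:
  model_path: '/PATH/TO/MODEL'        # placeholder; must be set manually
  max_prompt_tokens: 256
  max_response_tokens: 1024
  checkpoint_path: 'checkpoints/qwen2.5-1.5B-countdown'
  # load_checkpoint was removed in this commit
```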

scripts/config/alfworld.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -12,7 +12,6 @@ model:
   max_prompt_tokens: 4096
   max_response_tokens: 16384
   checkpoint_path: 'checkpoints/ALFWORLD_RFT'
-  load_checkpoint: true
 cluster:
   node_num: 1
   gpu_per_node: 8
```

scripts/config/countdown.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -14,7 +14,6 @@ model:
   max_prompt_tokens: 256
   max_response_tokens: 1024
   checkpoint_path: 'checkpoints/qwen2.5-1.5B-countdown'
-  load_checkpoint: true
 cluster:
   node_num: 1
   gpu_per_node: 8
```

scripts/config/dpo.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -13,7 +13,6 @@ model:
   max_prompt_tokens: 1792
   max_response_tokens: 256
   checkpoint_path: 'checkpoints/trinity_dpo'
-  load_checkpoint: true
 cluster:
   node_num: 1
   gpu_per_node: 8
```

scripts/config/gsm8k.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -25,7 +25,6 @@ model:
   max_prompt_tokens: 256
   max_response_tokens: 1024
   checkpoint_path: '/PATH/TO/CHECKPOINT/'
-  load_checkpoint: true
 cluster:
   node_num: 1
   gpu_per_node: 8
```

scripts/config/gsm8k_opmd.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -11,7 +11,6 @@ model:
   max_prompt_tokens: 256
   max_response_tokens: 1024
   checkpoint_path: '{path to checkpoints}/test-opmd-gsm8k/qwen2.5-1.5B-gsm8k-opmd-kl_0.001-entropy_0-tau_4-beta1_0.0-beta2_0.95-lr_2e-6-sync10'
-  load_checkpoint: false
 cluster:
   node_num: 1
   gpu_per_node: 8
```

scripts/config/webshop.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -12,7 +12,6 @@ model:
   max_prompt_tokens: 4096
   max_response_tokens: 16384
   checkpoint_path: 'checkpoints/WEBSHOP_RFT'
-  load_checkpoint: true
 cluster:
   node_num: 1
   gpu_per_node: 8
```

tests/common/tmp/template_config.yaml

Lines changed: 0 additions & 2 deletions

```diff
@@ -28,7 +28,6 @@ model:
   max_prompt_tokens: 2048
   max_response_tokens: 2048
   checkpoint_path: ''
-  load_checkpoint: true
 cluster:
   node_num: 1
   gpu_per_node: 8
@@ -61,7 +60,6 @@ trainer:
   trainer_config_path: tests/common/tmp/template_verl_config.yaml
 monitor:
   project: unittest
-  group: test
   name: test
 synchronizer:
   sync_method: offline
```

tests/test_data/template.yaml

Lines changed: 0 additions & 1 deletion

```diff
@@ -5,7 +5,6 @@ model:
   max_prompt_tokens: 2048
   max_response_tokens: 2048
   checkpoint_path: ''
-  load_checkpoint: true
 cluster:
   node_num: 1
   gpu_per_node: 8
```

trinity/common/config.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -83,7 +83,6 @@ class ModelConfig:
     max_response_tokens: int = 2048
     # The checkpoint directory, contains a latest dir link and multiple checkpoint dirs.
     checkpoint_path: str = ""
-    load_checkpoint: bool = True
 
 
 @dataclass
@@ -201,8 +200,6 @@ class MonitorConfig:
     # TODO: add more
     project: str = "trinity"
     name: str = "rft"
-    group: str = ""
-    run_id: str = ""
 
     # ! DO NOT SET
     # the root directory for cache and meta files, automatically generated
```
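After this change, the two trimmed dataclasses reduce to roughly the following. This is a sketch reconstructed from the hunks above: only the fields visible in the diff are shown, and the other fields of the real `ModelConfig` and `MonitorConfig` classes are omitted.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    max_response_tokens: int = 2048
    # The checkpoint directory, contains a latest dir link and multiple checkpoint dirs.
    checkpoint_path: str = ""
    # `load_checkpoint: bool = True` was removed by this commit


@dataclass
class MonitorConfig:
    project: str = "trinity"
    name: str = "rft"
    # `group` and `run_id` were removed by this commit


# Example: the values used by scripts/config/countdown.yaml
monitor = MonitorConfig(
    project="Trinity-RFT-countdown",
    name="qwen2.5-1.5B-countdown",
)
print(monitor.project)  # Trinity-RFT-countdown
```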
