docs/sphinx_doc/source/tutorial/example_data_functionalities.md (10 changes: 5 additions & 5 deletions)
@@ -42,7 +42,7 @@ data:
# database related. The result dataset will be stored in the database.
db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
# downstream loading related
- total_epoch: 1
+ total_epochs: 1
batch_size: 96
default_workflow_type: 'math_workflow'
```
@@ -53,7 +53,7 @@ Here you can set the basic information for the GSM-8K dataset, database informat
+ `dataset_config`: extra config arguments for loading the raw dataset. Mainly for the `load_dataset` method in HuggingFace `datasets` library.
+ `format_config`: some dataset format config items, which are used to map original data field names to unified ones.
+ `db_url`: the URL of the postgresql database to store the result dataset.
- + `total_epoch`: the total number of epochs to train on this dataset.
+ + `total_epochs`: the total number of epochs to train on this dataset.
+ `batch_size`: the training batch size.
+ `default_workflow_type`: the default exploring workflow type. Please refer to [programming guide](trinity_programming_guide.md) for more details.

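Putting the pieces together, a minimal `data` section for GSM-8K under the renamed key might look like the sketch below (paths and credentials are placeholders, not values taken from this PR):

```yaml
data:
  dataset_path: '/PATH/TO/gsm8k'   # placeholder raw dataset location
  # database related; the result dataset will be stored in the database
  db_url: 'postgresql://user@localhost:5432/trinity'   # placeholder URL
  # downstream loading related
  total_epochs: 1                  # renamed from total_epoch in this PR
  batch_size: 96
  default_workflow_type: 'math_workflow'
```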
@@ -74,7 +74,7 @@ data:
# database related. The result dataset will be stored in the database.
db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
# downstream loading related
- total_epoch: 1
+ total_epochs: 1
batch_size: 96
default_workflow_type: 'math_workflow'

@@ -120,7 +120,7 @@ data:
# database related. The result dataset will be stored in the database.
db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
# downstream loading related
- total_epoch: 1
+ total_epochs: 1
batch_size: 96
default_workflow_type: 'math_workflow'

@@ -199,7 +199,7 @@ data:
# database related. The result dataset will be stored in the database.
db_url: 'postgresql://{user_name}@localhost:5432/{db_name}'
# downstream loading related
- total_epoch: 20
+ total_epochs: 20
batch_size: 32
default_workflow_type: 'math_workflow'
```
docs/sphinx_doc/source/tutorial/trinity_configs.md (27 changes: 22 additions & 5 deletions)
@@ -3,6 +3,18 @@
The following is the main config file for Trinity-RFT. Take `countdown.yaml` as an example.


+ ## Monitor
+
+ ```yaml
+ monitor:
+   project: "Trinity-RFT-countdown"
+   name: "qwen2.5-1.5B-countdown"
+ ```
+
+ - `monitor.project`: The project name. It must be set manually.
+ - `monitor.name`: The name of the experiment. It must be set manually.
+

## Monitor

```yaml
@@ -33,7 +45,7 @@ data:
max_retry_times: 3
max_retry_interval: 1

- total_epoch: 20
+ total_epochs: 20
batch_size: 96
default_workflow_type: 'math_workflow'
default_reward_fn_type: 'countdown_reward'
@@ -47,7 +59,7 @@ data:
- `data.db_url`: The URL of the database.
- `data.max_retry_times`: The maximum number of retries when loading the dataset from the database.
- `data.max_retry_interval`: The maximum interval between retries when loading the dataset from the database.
- - `data.total_epoch`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually.
+ - `data.total_epochs`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually.
- `data.batch_size`: The number of `Task`s in one training batch. The real batch size used in training is `data.batch_size` * `actor_rollout_ref.rollout.n`. Default is `1`. It should be set manually.
- `data.default_workflow_type`: The default workflow type used for training.
- `data.default_reward_fn_type`: The default reward function type used for training.
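To make the batch-size arithmetic concrete, a small illustrative sketch (the rollout number is an assumed value, not from this PR):

```yaml
data:
  batch_size: 96        # number of Tasks per training batch
actor_rollout_ref:
  rollout:
    n: 8                # assumed rollouts per task
# effective training batch size = 96 * 8 = 768 samples
```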
@@ -345,10 +357,14 @@ algorithm:
gamma: 1.0
lam: 1.0
adv_estimator: gae
+ norm_adv_by_std_in_grpo: True
+ use_kl_in_reward: False
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
type: fixed
kl_coef: 0.001
+ horizon: 10000
+ target_kl: 0.1

trainer:
balance_batch: True
@@ -363,7 +379,7 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
+ resume_from_path: ""
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
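A hedged reading of the `resume_from_path` change above: the default is now an empty string rather than `False`, which suggests the field holds a checkpoint path when `resume_mode` is `resume_path`. A sketch with an assumed path:

```yaml
trainer:
  resume_mode: resume_path            # resume from an explicit checkpoint
  resume_from_path: '/PATH/TO/ckpt'   # assumed placeholder path
```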
@@ -383,8 +399,9 @@ trainer:
- `actor_rollout_ref.actor.grad_clip`: Gradient clip for actor model training.
- `actor_rollout_ref.actor.clip_ratio`: Used to compute the policy loss.
- `actor_rollout_ref.actor.entropy_coeff`: Used to compute the policy loss.
- - `actor_rollout_ref.actor.use_kl_loss`: True for GRPO.
- - `actor_rollout_ref.actor.kl_loss_coef`: Used for GRPO, optional value is `kl`, `abs`, `mse` or `low_var_kl`.
+ - `actor_rollout_ref.actor.use_kl_loss`: Whether to enable the kl loss.
+ - `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of the kl loss.
+ - `actor_rollout_ref.actor.kl_loss_type`: How to compute the kl loss; possible values are `kl`, `abs`, `mse`, or `low_var_kl`.
- `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size.
- `actor_rollout_ref.actor.alg_type`: Used for OPMD; possible values are `ppo`, `opmd`, or `pairwise_opmd`.
- `actor_rollout_ref.actor.tau`: Strength of regularization w.r.t. the old / reference policy.
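To make the kl-related switches concrete, a sketch of how these actor options might be combined for a GRPO-style run (key names follow the bullets above; the values are illustrative):

```yaml
actor_rollout_ref:
  actor:
    use_kl_loss: True          # enable the kl loss term
    kl_loss_coef: 0.001        # illustrative coefficient
    kl_loss_type: low_var_kl   # one of: kl, abs, mse, low_var_kl
```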
examples/dpo_humanlike/dpo.yaml (3 changes: 1 addition & 2 deletions)
@@ -1,6 +1,6 @@
mode: train
data:
- total_epoch: 20
+ total_epochs: 20
batch_size: 32 # NOTE
train_split: "train"
dataset_path: ''
@@ -22,7 +22,6 @@ buffer:
train_dataset:
name: dpo_buffer
storage_type: file
- algorithm_type: dpo
path: '/PATH/TO/DATASET/'
kwargs:
prompt_type: plaintext # plaintext/messages
examples/dpo_humanlike/train_dpo.yaml (1 change: 0 additions & 1 deletion)
@@ -173,7 +173,6 @@ trainer:
save_freq: 30
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 5
critic_warmup: 0
default_hdfs_dir: null
examples/grpo_alfworld/alfworld.yaml (3 changes: 1 addition & 2 deletions)
@@ -1,5 +1,5 @@
data:
- total_epoch: 20
+ total_epochs: 20
batch_size: 4
dataset_path: 'scripts/data_prepare/alfworld_data'
default_workflow_type: 'alfworld_workflow'
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: alfworld_buffer
storage_type: queue
- algorithm_type: ppo
path: 'sqlite:///alfworld.db'
explorer:
engine_type: vllm_async
examples/grpo_alfworld/train_alfworld.yaml (1 change: 0 additions & 1 deletion)
@@ -172,7 +172,6 @@ trainer:
save_freq: 1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
examples/grpo_gsm8k/gsm8k.yaml (4 changes: 1 addition & 3 deletions)
@@ -18,7 +18,7 @@ data:
# db related
db_url: ''
# downstream loading related
- total_epoch: 1
+ total_epochs: 1
batch_size: 96
default_workflow_type: 'math_workflow'
model:
@@ -35,12 +35,10 @@ buffer:
train_dataset:
name: gsm8k_buffer
storage_type: queue
- algorithm_type: ppo
path: 'sqlite:///gsm8k.db'
# sft_warmup_dataset: # Uncomment these to enable sft warmup
# name: warmup_data
# storage_type: file
- # algorithm_type: sft
# path: '/PATH/TO/WARMUP_DATA/'
# kwargs:
# prompt_type: plaintext
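For readers who want to try sft warmup, a sketch of the commented block above once uncommented; note that `algorithm_type: sft` is gone after this PR, and the dataset path is a placeholder:

```yaml
buffer:
  sft_warmup_dataset:
    name: warmup_data
    storage_type: file
    path: '/PATH/TO/WARMUP_DATA/'   # placeholder
    kwargs:
      prompt_type: plaintext        # plaintext or messages
```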
examples/grpo_gsm8k/train_gsm8k.yaml (1 change: 0 additions & 1 deletion)
@@ -177,7 +177,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 5
critic_warmup: 0
default_hdfs_dir: null
examples/grpo_math/math.yaml (5 changes: 2 additions & 3 deletions)
@@ -10,7 +10,7 @@ data:
# db related
db_url: ''
# downstream loading related
- total_epoch: 20
+ total_epochs: 20
batch_size: 288
default_workflow_type: 'math_workflow'
model:
@@ -27,8 +27,7 @@ buffer:
train_dataset:
name: math_buffer
storage_type: queue
- algorithm_type: ppo
- path: 'sqlite:////math.db'
+ path: 'sqlite:///math.db'
explorer:
engine_type: vllm_async
engine_num: 2
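A note on the sqlite URL fix above: in SQLAlchemy-style URLs, `sqlite:///` is followed by a relative path, while a fourth slash starts an absolute path, so `sqlite:////math.db` pointed at a file in the filesystem root. Both forms for comparison (the absolute path is illustrative):

```yaml
path: 'sqlite:///math.db'         # relative: ./math.db (what this PR uses)
# path: 'sqlite:////data/math.db' # absolute: /data/math.db
```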
examples/grpo_math/train_math.yaml (1 change: 0 additions & 1 deletion)
@@ -169,7 +169,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 5
critic_warmup: 0
default_hdfs_dir: null
examples/grpo_sciworld/sciworld.yaml (3 changes: 1 addition & 2 deletions)
@@ -1,5 +1,5 @@
data:
- total_epoch: 20
+ total_epochs: 20
batch_size: 4
dataset_path: 'scripts/data_prepare/sciworld_data'
default_workflow_type: 'sciworld_workflow'
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: sciworld_buffer
storage_type: queue
- algorithm_type: ppo
path: 'sqlite:///sciworld.db'
explorer:
engine_type: vllm_async
examples/grpo_sciworld/train_sciworld.yaml (1 change: 0 additions & 1 deletion)
@@ -167,7 +167,6 @@ trainer:
save_freq: 1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
examples/grpo_webshop/train_webshop.yaml (1 change: 0 additions & 1 deletion)
@@ -172,7 +172,6 @@ trainer:
save_freq: 1
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
examples/grpo_webshop/webshop.yaml (3 changes: 1 addition & 2 deletions)
@@ -1,5 +1,5 @@
data:
- total_epoch: 20
+ total_epochs: 20
batch_size: 4
dataset_path: 'scripts/data_prepare/webshop_data'
default_workflow_type: 'webshop_workflow'
@@ -21,7 +21,6 @@ buffer:
train_dataset:
name: webshop_buffer
storage_type: queue
- algorithm_type: ppo
path: 'sqlite:///webshop.db'
explorer:
engine_type: vllm_async
examples/opmd_gsm8k/opmd_gsm8k.yaml (3 changes: 1 addition & 2 deletions)
@@ -1,5 +1,5 @@
data:
- total_epoch: 1
+ total_epochs: 1
batch_size: 96
dataset_path: '{path to datasets}/gsm8k'
default_workflow_type: 'math_workflow'
@@ -20,7 +20,6 @@ buffer:
train_dataset:
name: gsm8k_buffer
storage_type: queue
- algorithm_type: opmd
path: 'sqlite:///gsm8k_opmd.db'
explorer:
engine_type: vllm_async
examples/opmd_gsm8k/train_opmd_gsm8k.yaml (1 change: 0 additions & 1 deletion)
@@ -204,7 +204,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
examples/ppo_countdown/countdown.yaml (5 changes: 2 additions & 3 deletions)
@@ -1,5 +1,5 @@
data:
- total_epoch: 20
+ total_epochs: 20
batch_size: 96
dataset_path: 'countdown_dataset/oneshot-split'
default_workflow_type: 'math_workflow'
@@ -23,8 +23,7 @@ buffer:
train_dataset:
name: countdown_buffer
storage_type: queue
- algorithm_type: ppo
- path: 'sqlite:////countdown.db'
+ path: 'sqlite:///countdown.db'
explorer:
engine_type: vllm_async
engine_num: 2
examples/ppo_countdown/train_countdown.yaml (1 change: 0 additions & 1 deletion)
@@ -179,7 +179,6 @@ trainer:
save_freq: 100
# auto: find the last ckpt to resume. If can't find, start from scratch
resume_mode: auto # or auto or resume_path if
- resume_from_path: False
test_freq: 100
critic_warmup: 0
default_hdfs_dir: null
pyproject.toml (1 change: 1 addition & 0 deletions)
@@ -32,6 +32,7 @@ dependencies = [
"math_verify",
"ninja",
"fire",
"streamlit",
"flask",
"requests",
"tensorboard",
tests/common/tmp/template_config.yaml (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
mode: both
data:
dataset_path: ''
- total_epoch: 1
+ total_epochs: 1
batch_size: 32
train_split: 'train'
eval_split: ''