3 changes: 1 addition & 2 deletions docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -40,7 +40,7 @@ Note that the dataset has the keys `prompt`, `chosen` and `rejected`. If not, pa

We use the configurations in [`dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/dpo.yaml) and [`train_dpo.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/dpo_humanlike/train_dpo.yaml) for this experiment. Some important setups are listed in the following:

We run the experiment in a train mode, as there is no Explorer. To enable this mode, we config `mode` to `train` and set `sync_method` to `checkpoint`. The value of `sync_iteration_interval` can be set as same of the value of `save_interval`.
We run the experiment in train mode, as there is no Explorer. To enable this mode, we set `mode` to `train` and `sync_method` to `checkpoint`.

```yaml
# In dpo.yaml
@@ -50,7 +50,6 @@ synchronizer:
buffer:
train_dataset:
storage_type: file
algorithm_type: dpo
path: <$DATASET_PATH/human_like_dpo_dataset>
kwargs:
prompt_type: <prompt_type> # messages/plaintext
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
@@ -20,7 +20,7 @@ To try out the OPMD algorithm:
trinity run --config examples/opmd_gsm8k/opmd_gsm8k.yaml
```

Note that in this config file, `sync_iteration_interval` is set to 10, i.e., the model weights of explorer and trainer are synchronized only once every 10 training steps, which leads to a challenging off-policy scenario (potentially with abrupt distribution shift during the RFT process).
Note that in this config file, `sync_interval` is set to 10, i.e., the model weights of explorer and trainer are synchronized only once every 10 training steps, which leads to a challenging off-policy scenario (potentially with abrupt distribution shift during the RFT process).
Other configurations of particular interest are explained at the beginning of [`train_opmd_gsm8k.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/opmd_gsm8k/train_opmd_gsm8k.yaml).
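
For reference, the relevant synchronizer block of this example (as it appears in `opmd_gsm8k.yaml` later in this PR; shown here as a sketch rather than the full file) is:

```yaml
synchronizer:
  sync_method: 'nccl'   # explorer pulls model weights from trainer via NCCL
  sync_interval: 10     # synchronize only once every 10 training steps (off-policy)
  sync_timeout: 1200
```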


@@ -30,7 +30,7 @@ Other configurations of particular interest are explained at the beginning of [`
The red curve below shows an example of OPMD's learning curves.
Since the explorer's model weights remain unchanged for the first 10 steps, its score remains flat.
Then, after the model weights of explorer and trainer are synchronized at the end of step 10, we see an abrupt increase in score at step 11, which indicates effective off-policy learning in the first 10 steps.
A similar performance boost is shown at step 21, which leads to a converged score matching what is achieved by GRPO in a mostly on-policy case (with `sync_iteration_interval=2`).
A similar performance boost is shown at step 21, which leads to a converged score matching what is achieved by GRPO in a mostly on-policy case (with `sync_interval=2`).



9 changes: 4 additions & 5 deletions docs/sphinx_doc/source/tutorial/example_reasoning_basic.md
@@ -37,13 +37,13 @@ More details on dataset downloading are referred to [ModelScope](https://modelsc

### Synchronous Mode of Trinity-RFT

We run the experiment in a synchronous mode where the Explorer and Trainer operate in turn. To enable this mode, we config `mode` to `both` (default) and set `sync_iteration_interval` properly. A smaller value of `sync_iteration_interval` makes the training closer to an on-policy setup.
We run the experiment in synchronous mode, where the Explorer and Trainer operate in turn. To enable this mode, we set `mode` to `both` (the default) and choose `sync_interval` appropriately. A smaller value of `sync_interval` keeps the training closer to an on-policy setup.

```yaml
mode: both
synchronizer:
sync_method: 'nccl'
sync_iteration_interval: 2
sync_interval: 2
```

### Use GRPO or PPO Algorithm
@@ -76,21 +76,20 @@ trinity run --config examples/grpo_gsm8k/gsm8k.yaml

## Optional: RFT with SFT Warmup

Before RFT, we may use SFT as a warmup step. We need to set `trainer.sft_warmup_iteration > 0` and prepare the SFT data to `buffer.train_dataset.path=$DATASET_PATH/{sft_data}`.
Before RFT, we may use SFT as a warmup step. We need to set `trainer.sft_warmup_steps > 0` and point `buffer.sft_warmup_dataset.path` to the SFT data, e.g. `$DATASET_PATH/{sft_data}`.

```yaml
# Properly set the following configs in gsm8k.yaml
buffer:
sft_warmup_dataset:
storage_type: file
algorithm_type: sft
path: <$DATASET_PATH/{sft_data}>
kwargs:
prompt_type: <prompt_type> # messages/plaintext/chatpair
prompt_key: <prompt_key>
response_key: <response_key>
trainer:
sft_warmup_iteration: 10
sft_warmup_steps: 10
```

The following command runs SFT and RFT in sequence:
12 changes: 6 additions & 6 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -49,7 +49,7 @@ data:
- `data.max_retry_times`: The maximum number of retries when loading the dataset from the database.
- `data.max_retry_interval`: The maximum interval between retries when loading the dataset from the database.
- `data.total_epochs`: The total number of epochs to explore the dataset. Default is `1`. It should be set manually.
- `data.batch_size`: The number of `Task` in one training batch. The real batch size used in training is `data.batch_size` * `actor_rollout_ref.rollout.n` Default is `1`. It should be set manually.
- `data.batch_size`: The number of `Task`s in one training batch. The real batch size used in training is `data.batch_size` * `explorer.repeat_times` (see the worked example after this list). It should be set manually.
- `data.default_workflow_type`: The default workflow type used for training.
- `data.default_reward_fn_type`: The default reward function type used for training.
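
As a quick worked example of the batch-size arithmetic above (the numbers are purely illustrative, not defaults):

```yaml
data:
  batch_size: 16      # number of Task entries per training batch
explorer:
  repeat_times: 8     # rollouts generated per Task
# real training batch size = 16 * 8 = 128 samples
```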

@@ -150,14 +150,14 @@ explorer:
```yaml
synchronizer:
sync_method: 'nccl'
sync_iteration_interval: 10
sync_interval: 10
sync_timeout: 1200
```

- `synchronizer.sync_method`: The synchronization method between `trainer` and `explorer`.
Supports `nccl` and `checkpoint`: `nccl` means the model weights in `explorer` are synchronized from `trainer` through NCCL,
while `checkpoint` means `explorer` loads the newest checkpoint saved by `trainer` and then updates its model weights from it (see the checkpoint-based sketch after this list). Default is `nccl`.
- `synchronizer.sync_iteration_interval`: The interval between two synchronizations. Default is `10`. It should be set manually.
- `synchronizer.sync_interval`: The number of training steps between two synchronizations. Default is `10`. It should be set manually.
- `synchronizer.sync_timeout`: The timeout of the synchronization. Default is `1200`.
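
For the checkpoint-based alternative, the DPO example in this PR uses a configuration along the lines of the following sketch:

```yaml
synchronizer:
  sync_method: 'checkpoint'  # explorer reloads the newest checkpoint saved by trainer
  sync_interval: 30
  sync_timeout: 1200
```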

## Trainer
@@ -167,15 +167,15 @@ trainer:
trainer_type: 'verl'
algorithm_type: ppo
trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
sft_warmup_iteration: 0
sft_warmup_steps: 0
eval_interval: 1000
save_interval: 100
```

- `trainer.trainer_type`: The backend of the trainer. Only `verl` is supported.
- `trainer.algorithm_type`: The type of the algorithm. Supports `ppo`, `grpo`, `opmd` and `dpo`.
- `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually.
- `trainer.sft_warmup_iteration`: The number of iterations to warm up the model. Default is `0`.
- `trainer.sft_warmup_steps`: The number of SFT steps used to warm up the model before RFT (see the sketch after this list). Default is `0`.
- `trainer.eval_interval`: The interval between two evaluations. Default is `1000`.
- `trainer.save_interval`: The interval between two checkpoints. Default is `100`.
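
To enable SFT warmup, the minimal trainer-side change is sketched below (the value is illustrative; a matching `buffer.sft_warmup_dataset` must also be configured, as in the GSM8K example earlier in this PR):

```yaml
trainer:
  sft_warmup_steps: 10   # run 10 SFT steps before RFT; 0 disables warmup
```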

@@ -418,7 +418,7 @@ trainer:
- `trainer.balance_batch`: Whether to balance batch size between GPUs during training.
- `trainer.resume_mode`: Resume mode for training. Supports `disable`, `auto` and `resume_path`.
- `trainer.resume_from_path`: Path to resume from (see the sketch after this list).
- `trainer.critic_warmup`: The number of iteration to train the critic model before actual policy learning.
- `trainer.critic_warmup`: The number of steps to train the critic model before actual policy learning.
- `trainer.default_hdfs_dir`: Default HDFS directory for saving checkpoints.
- `trainer.remove_previous_ckpt_in_save`: Whether to remove previous checkpoints in save.
- `trainer.del_local_ckpt_after_load`: Whether to delete local checkpoints after loading.
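
A sketch of the resume-related settings; that `resume_from_path` is only consulted when `resume_mode` is `resume_path` is an assumption based on the field descriptions above:

```yaml
trainer:
  resume_mode: resume_path          # one of: disable / auto / resume_path
  resume_from_path: /path/to/ckpt   # hypothetical path, used with resume_mode: resume_path
```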
4 changes: 2 additions & 2 deletions examples/async_gsm8k/explorer.yaml
@@ -48,9 +48,9 @@ synchronizer:
sync_iteration_interval: 10
trainer:
trainer_type: 'verl'
algorithm_type: ppo
algorithm_type: grpo
trainer_config_path: examples/async_gsm8k/verl_config.yaml
sft_warmup_iteration: 0 # Set to integer to enable sft warmup
sft_warmup_steps: 0 # Set to integer to enable sft warmup
eval_interval: 10
monitor:
cache_root_dir: ""
4 changes: 2 additions & 2 deletions examples/async_gsm8k/trainer.yaml
@@ -48,9 +48,9 @@ synchronizer:
sync_iteration_interval: 10
trainer:
trainer_type: 'verl'
algorithm_type: ppo
algorithm_type: grpo
trainer_config_path: examples/async_gsm8k/verl_config.yaml
sft_warmup_iteration: 0 # Set to integer to enable sft warmup
sft_warmup_steps: 0 # Set to integer to enable sft warmup
eval_interval: 10
monitor:
cache_root_dir: ""
2 changes: 1 addition & 1 deletion examples/dpo_humanlike/dpo.yaml
@@ -46,7 +46,7 @@ explorer:
max_waiting_steps: 4
synchronizer:
sync_method: 'checkpoint'
sync_iteration_interval: 30
sync_interval: 30
sync_timeout: 1200
trainer:
trainer_type: 'verl'
6 changes: 3 additions & 3 deletions examples/grpo_alfworld/alfworld.yaml
@@ -39,14 +39,14 @@ explorer:
max_pending_requests: 32
max_waiting_steps: 4
gpu_memory_utilization: 0.7
enable_chunked_prefil: true
enable_chunked_prefill: true
synchronizer:
sync_method: 'nccl'
sync_iteration_interval: 8
sync_interval: 8
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: ppo
algorithm_type: grpo
trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml'
save_interval: 10
monitor:
6 changes: 3 additions & 3 deletions examples/grpo_gsm8k/gsm8k.yaml
@@ -60,13 +60,13 @@ explorer:
max_waiting_steps: 4
synchronizer:
sync_method: 'nccl'
sync_iteration_interval: 2
sync_interval: 2
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: ppo
algorithm_type: grpo
trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
sft_warmup_iteration: 0 # Set to integer to enable sft warmup
sft_warmup_steps: 0 # Set to integer to enable sft warmup
eval_interval: 50
save_interval: 100
# get_exp_strategy: 'LFU'
6 changes: 3 additions & 3 deletions examples/grpo_math/math.yaml
@@ -46,13 +46,13 @@ explorer:
max_waiting_steps: 4
synchronizer:
sync_method: 'nccl'
sync_iteration_interval: 2
sync_interval: 2
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: ppo
algorithm_type: grpo
trainer_config_path: 'examples/grpo_math/train_math.yaml'
sft_warmup_iteration: 0 # Set to integer to enable sft warmup
sft_warmup_steps: 0 # Set to integer to enable sft warmup
eval_interval: 10
save_interval: 100
monitor:
6 changes: 3 additions & 3 deletions examples/grpo_sciworld/sciworld.yaml
@@ -39,14 +39,14 @@ explorer:
max_pending_requests: 32
max_waiting_steps: 4
gpu_memory_utilization: 0.7
enable_chunked_prefil: true
enable_chunked_prefill: true
synchronizer:
sync_method: 'nccl'
sync_iteration_interval: 8
sync_interval: 8
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: ppo
algorithm_type: grpo
trainer_config_path: 'examples/grpo_sciworld/train_sciworld.yaml'
save_interval: 10
monitor:
6 changes: 3 additions & 3 deletions examples/grpo_webshop/webshop.yaml
@@ -39,14 +39,14 @@ explorer:
max_pending_requests: 32
max_waiting_steps: 4
gpu_memory_utilization: 0.7
enable_chunked_prefil: true
enable_chunked_prefill: true
synchronizer:
sync_method: 'nccl'
sync_iteration_interval: 8
sync_interval: 8
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: ppo
algorithm_type: grpo
trainer_config_path: 'examples/grpo_webshop/train_webshop.yaml'
save_interval: 10
monitor:
4 changes: 2 additions & 2 deletions examples/opmd_gsm8k/opmd_gsm8k.yaml
@@ -39,13 +39,13 @@ explorer:
max_waiting_steps: 4
synchronizer:
sync_method: 'nccl'
sync_iteration_interval: 10
sync_interval: 10
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: opmd
trainer_config_path: 'examples/opmd_gsm8k/train_opmd_gsm8k.yaml'
sft_warmup_iteration: 0
sft_warmup_steps: 0
save_interval: 100
monitor:
cache_root_dir: ""
4 changes: 2 additions & 2 deletions examples/opmd_gsm8k/train_opmd_gsm8k.yaml
@@ -68,8 +68,8 @@ actor_rollout_ref:
# min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
beta1: 0.0 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_iteration_interval)
beta2: 0.95 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_iteration_interval)
beta1: 0.0 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
beta2: 0.95 # set to smaller value for scenarios with abrupt distribution shift (e.g., large sync_interval)
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
4 changes: 2 additions & 2 deletions examples/ppo_countdown/countdown.yaml
@@ -42,13 +42,13 @@ explorer:
max_waiting_steps: 4
synchronizer:
sync_method: 'nccl'
sync_iteration_interval: 10
sync_interval: 10
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: ppo
trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
sft_warmup_iteration: 0
sft_warmup_steps: 0
eval_interval: 1000
save_interval: 100
monitor:
2 changes: 1 addition & 1 deletion tests/common/config_test.py
@@ -19,7 +19,7 @@ def test_load_default_config(self):
self.assertEqual(config.trainer.trainer_config.trainer.project_name, config.monitor.project)
self.assertEqual(
config.trainer.trainer_config.trainer.save_freq,
config.synchronizer.sync_iteration_interval,
config.synchronizer.sync_interval,
)

def test_all_examples_are_valid(self):
12 changes: 6 additions & 6 deletions tests/data/core/dataset_test.py
@@ -82,12 +82,12 @@ def test_to_taskset(self):
def test_to_taskset_with_global_settings(self):
dataset = RftDataset(data_config=self.data_config, reward_schema="default")
taskset = dataset.to_taskset(
reward_fn=AccuracyReward(),
reward_fn=AccuracyReward,
workflow=SimpleWorkflow,
)
self.assertIsInstance(taskset, TaskSet)
self.assertEqual(taskset.workflow, SimpleWorkflow)
self.assertIsInstance(taskset.reward_fn, AccuracyReward)
self.assertEqual(taskset.reward_fn, AccuracyReward)

def test_to_taskset_with_sample_level_settings(self):
dataset = RftDataset(
@@ -97,22 +97,22 @@ def test_to_taskset_with_sample_level_settings(self):
self.assertIsInstance(taskset, TaskSet)
for task in taskset.tasks:
self.assertEqual(task.workflow, MathWorkflow)
self.assertIsInstance(task.reward_fn, AccuracyReward)
self.assertEqual(task.reward_fn, AccuracyReward)

def test_to_taskset_with_both_settings(self):
dataset = RftDataset(
data_config=self.data_config_sample_level_setting, reward_schema="default"
)
taskset = dataset.to_taskset(
reward_fn=AccuracyReward(),
reward_fn=AccuracyReward,
workflow=SimpleWorkflow,
)
self.assertIsInstance(taskset, TaskSet)
for task in taskset.tasks:
self.assertEqual(task.workflow, MathWorkflow)
self.assertIsInstance(task.reward_fn, AccuracyReward)
self.assertEqual(task.reward_fn, AccuracyReward)
self.assertEqual(taskset.workflow, SimpleWorkflow)
self.assertIsInstance(taskset.reward_fn, AccuracyReward)
self.assertEqual(taskset.reward_fn, AccuracyReward)


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/explorer/explorer_test.py
@@ -24,7 +24,7 @@ def setUp(self):
self.config.monitor.monitor_type = MonitorType.TENSORBOARD
self.config.monitor.project = "Trinity-unittest"
self.config.model.checkpoint_path = get_checkpoint_path()
self.config.synchronizer.sync_iteration_interval = 2
self.config.synchronizer.sync_interval = 2
self.config.explorer.eval_interval = 4
self.config.trainer.eval_interval = 4

4 changes: 2 additions & 2 deletions tests/template/config.yaml
@@ -36,14 +36,14 @@ explorer:
trainer:
trainer_type: verl
trainer_config_path: tests/template/verl_config.yaml
sft_warmup_iteration: 0
sft_warmup_steps: 0
eval_interval: 1000
save_interval: 100
monitor:
project: unittest
name: test
synchronizer:
sync_method: checkpoint
sync_iteration_interval: 10
sync_interval: 10
sync_timeout: 1200
wait_for_checkpoint: false
2 changes: 1 addition & 1 deletion tests/tools.py
@@ -33,7 +33,7 @@ def get_checkpoint_path() -> str:


def get_unittest_dataset_config(dataset_name: str = "countdown") -> DataConfig:
"""Countdown sample dataset for 8 iterations"""
"""Countdown sample dataset for 8 steps"""
if dataset_name == "countdown":
return DataConfig(
total_epochs=2,