2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_dpo.md
@@ -56,7 +56,7 @@ buffer:
prompt_key: <prompt_key>
chosen_key: <chosen_key>
rejected_key: <rejected_key>
trainer:
global_config:
algorithm_type: dpo

# In train_dpo.yaml
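Read together, the DPO example now selects the algorithm under `global_config` rather than `trainer`. A consolidated view of the snippet above (nesting under `buffer` is abbreviated; the `<...>` placeholders are the document's own):

```yaml
buffer:
  # ... other buffer settings ...
  prompt_key: <prompt_key>
  chosen_key: <chosen_key>
  rejected_key: <rejected_key>
global_config:
  algorithm_type: dpo  # previously set as trainer.algorithm_type
```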
4 changes: 2 additions & 2 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -7,13 +7,15 @@ The following is the main config file for Trinity-RFT. Take `countdown.yaml` as
```yaml
mode: both
global_config:
algorithm_type: ppo
total_epochs: 1
batch_size: 96
eval_interval: 1000
eval_on_latest_ckp: true
```

- `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
- `global_config.algorithm_type`: The type of the algorithm. Supported values are `ppo`, `grpo`, `opmd`, and `dpo`.
- `global_config.total_epochs`: The total number of epochs. It should be checked manually.
- `global_config.batch_size`: The batch size used for training. It should be checked manually.
- `global_config.eval_interval`: The interval steps between two evaluations. Default is `1000`.
@@ -192,15 +194,13 @@ Support `nccl` and `checkpoint`, `nccl` represents that model weights in `explor
```yaml
trainer:
trainer_type: 'verl'
algorithm_type: ppo
trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
sft_warmup_steps: 0
eval_interval: 1000
save_interval: 100
```

- `trainer.trainer_type`: The backend of the trainer. Only `verl` is supported.
- `trainer.algorithm_type`: The type of the algorithm, Support `ppo`, `grpo`, `opmd` and `dpo`.
- `trainer.trainer_config_path`: The path to the trainer configuration file. It must be set manually.
- `trainer.sft_warmup_steps`: The number of steps to warm up the model. Default is `0`.
- `trainer.eval_interval`: The interval steps between two evaluations. Default is `1000`.
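Taken together, the two hunks above move the algorithm selection out of the `trainer` block and into `global_config`. A sketch of how the documented fragments combine after this change (values are the ones shown in the documentation snippets):

```yaml
mode: both
global_config:
  algorithm_type: ppo  # previously trainer.algorithm_type
  total_epochs: 1
  batch_size: 96
  eval_interval: 1000
  eval_on_latest_ckp: true
trainer:
  trainer_type: 'verl'  # algorithm_type is no longer set here
  trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
  sft_warmup_steps: 0
  eval_interval: 1000
  save_interval: 100
```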
2 changes: 1 addition & 1 deletion examples/async_gsm8k/explorer.yaml
@@ -3,6 +3,7 @@ global_config:
total_epochs: 20
batch_size: 96
eval_interval: 10
algorithm_type: grpo
model:
model_path: /PATH/TO/MODEL/
max_prompt_tokens: 256
@@ -51,7 +52,6 @@ synchronizer:
sync_iteration_interval: 10
trainer:
trainer_type: 'verl'
algorithm_type: grpo
trainer_config_path: examples/async_gsm8k/verl_config.yaml
sft_warmup_steps: 0 # Set to integer to enable sft warmup
monitor:
2 changes: 1 addition & 1 deletion examples/async_gsm8k/trainer.yaml
@@ -3,6 +3,7 @@ global_config:
total_epochs: 20
batch_size: 96
eval_interval: 10
algorithm_type: grpo
model:
model_path: /PATH/TO/MODEL/
max_prompt_tokens: 256
@@ -50,7 +51,6 @@ synchronizer:
sync_iteration_interval: 10
trainer:
trainer_type: 'verl'
algorithm_type: grpo
trainer_config_path: examples/async_gsm8k/verl_config.yaml
sft_warmup_steps: 0 # Set to integer to enable sft warmup
monitor:
2 changes: 1 addition & 1 deletion examples/dpo_humanlike/dpo.yaml
@@ -2,6 +2,7 @@ mode: train
global_config:
total_epochs: 20
batch_size: 32 # NOTE
algorithm_type: dpo
model:
model_path: '/PATH/TO/MODEL/CHECKPOINT/' # NOTE
max_prompt_tokens: 1792
@@ -29,7 +30,6 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: dpo
trainer_config_path: 'examples/dpo_humanlike/train_dpo.yaml'
save_interval: 30
monitor:
2 changes: 1 addition & 1 deletion examples/grpo_alfworld/alfworld.yaml
@@ -1,6 +1,7 @@
global_config:
total_epochs: 20
batch_size: 4
algorithm_type: grpo
model:
model_path: '/PATH/TO/MODEL/CHECKPOINT/'
max_prompt_tokens: 4096
@@ -50,7 +51,6 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: grpo
trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml'
save_interval: 10
monitor:
2 changes: 1 addition & 1 deletion examples/grpo_gsm8k/gsm8k.yaml
@@ -15,6 +15,7 @@ global_config:
total_epochs: 1
batch_size: 96
eval_interval: 50
algorithm_type: grpo
model:
model_path: '/PATH/TO/MODEL/'
max_prompt_tokens: 256
@@ -80,7 +81,6 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: grpo
trainer_config_path: 'examples/grpo_gsm8k/train_gsm8k.yaml'
sft_warmup_steps: 0 # Set to integer to enable sft warmup
save_interval: 100
2 changes: 1 addition & 1 deletion examples/grpo_math/math.yaml
@@ -2,6 +2,7 @@ global_config:
total_epochs: 20
batch_size: 288
eval_interval: 10
algorithm_type: grpo
model:
model_path: /PATH/TO/MODEL/
max_prompt_tokens: 1024
@@ -50,7 +51,6 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: grpo
trainer_config_path: 'examples/grpo_math/train_math.yaml'
sft_warmup_steps: 0 # Set to integer to enable sft warmup
save_interval: 100
2 changes: 1 addition & 1 deletion examples/grpo_sciworld/sciworld.yaml
@@ -1,6 +1,7 @@
global_config:
total_epochs: 20
batch_size: 4
algorithm_type: grpo
model:
model_path: '/PATH/TO/MODEL/CHECKPOINT/'
max_prompt_tokens: 4096
@@ -50,7 +51,6 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: grpo
trainer_config_path: 'examples/grpo_sciworld/train_sciworld.yaml'
save_interval: 10
monitor:
2 changes: 1 addition & 1 deletion examples/grpo_webshop/webshop.yaml
@@ -1,6 +1,7 @@
global_config:
total_epochs: 20
batch_size: 4
algorithm_type: grpo
model:
model_path: '/PATH/TO/MODEL/CHECKPOINT/'
max_prompt_tokens: 4096
@@ -50,7 +51,6 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: grpo
trainer_config_path: 'examples/grpo_webshop/train_webshop.yaml'
save_interval: 10
monitor:
2 changes: 1 addition & 1 deletion examples/opmd_gsm8k/opmd_gsm8k.yaml
@@ -1,6 +1,7 @@
global_config:
total_epochs: 1
batch_size: 96
algorithm_type: opmd
model:
model_path: '{path to models}/Qwen2.5-1.5B-Inst'
max_prompt_tokens: 256
@@ -49,7 +50,6 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: opmd
trainer_config_path: 'examples/opmd_gsm8k/train_opmd_gsm8k.yaml'
sft_warmup_steps: 0
save_interval: 100
2 changes: 1 addition & 1 deletion examples/ppo_countdown/countdown.yaml
@@ -2,6 +2,7 @@ global_config:
total_epochs: 20
batch_size: 96
eval_interval: 1000
algorithm_type: ppo
model:
model_path: '/PATH/TO/MODEL/CHECKPOINT/'
max_prompt_tokens: 256
@@ -51,7 +52,6 @@ synchronizer:
sync_timeout: 1200
trainer:
trainer_type: 'verl'
algorithm_type: ppo
trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
sft_warmup_steps: 0
save_interval: 100
4 changes: 2 additions & 2 deletions trinity/cli/launcher.py
@@ -56,7 +56,7 @@ def train(config: Config) -> None:
logger.info("SFT warmup finished.")
break

algo_type = config.trainer.algorithm_type
algo_type = config.global_config.algorithm_type
try:
ray.get(trainer.train.remote(algo_type))
logger.info("Train finished.")
@@ -100,7 +100,7 @@ def both(config: Config) -> None:
break
ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])

algo_type = config.trainer.algorithm_type
algo_type = config.global_config.algorithm_type
while True:
try:
ref_explore = explorer.explore_one_period.remote()
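Both `train` and `both` now resolve the algorithm from `config.global_config.algorithm_type` before handing it to the trainer actor, so the key the launcher consumes is the one under `global_config`. A minimal illustrative fragment (not a complete config):

```yaml
mode: train
global_config:
  algorithm_type: grpo  # read by launcher.py and passed to trainer.train.remote(algo_type)
```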
33 changes: 19 additions & 14 deletions trinity/common/config.py
@@ -120,6 +120,7 @@ class GlobalConfig:
batch_size: int = 1
eval_interval: int = 100
eval_on_latest_ckp: bool = True
algorithm_type: AlgorithmType = AlgorithmType.PPO


@dataclass
@@ -227,7 +228,6 @@ class TrainerConfig:
trainer_config: Any = field(default_factory=dict)

# train algorithm
algorithm_type: AlgorithmType = AlgorithmType.PPO
get_exp_strategy: Optional[str] = None

# warmup config
@@ -309,7 +309,7 @@ def _check_interval(self) -> None:
# check eval_interval
if (
self.mode != "bench"
and self.trainer.algorithm_type != AlgorithmType.DPO
and self.global_config.algorithm_type != AlgorithmType.DPO
and self.global_config.eval_interval % self.synchronizer.sync_interval != 0
):
self.global_config.eval_interval = (
@@ -322,12 +322,12 @@ def _check_interval(self) -> None:
# check save_interval
if (
self.mode != "bench"
and self.trainer.algorithm_type != AlgorithmType.DPO
and self.global_config.algorithm_type != AlgorithmType.DPO
and self.synchronizer.sync_method == SyncMethod.CHECKPOINT
):
if self.trainer.save_interval != self.synchronizer.sync_interval:
logger.warning(
f"When `trainer.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, "
f"When `global_config.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, "
f"`trainer.save_interval` will be set to "
f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`."
)
@@ -390,20 +390,24 @@ def _check_buffer(self) -> None: # noqa: C901
f"Auto set `buffer.trainer_input.experience_buffer` to {self.buffer.trainer_input.experience_buffer}"
)
elif self.mode == "train": # TODO: to be check
if self.trainer.algorithm_type.is_dpo():
if self.global_config.algorithm_type.is_dpo():
if (
self.buffer.trainer_input.experience_buffer is None
or not self.buffer.trainer_input.experience_buffer.path
):
raise ValueError(
"`buffer.trainer_input.experience_buffer.path` is required when `trainer.algorithm_type == AlgorithmType.DPO`"
"`buffer.trainer_input.experience_buffer.path` is required when `global_config.algorithm_type == AlgorithmType.DPO`"
)
if self.mode in ["both", "train"]:
self.buffer.trainer_input.experience_buffer.algorithm_type = self.trainer.algorithm_type
if self.buffer.trainer_input.experience_buffer is not None:
self.buffer.trainer_input.experience_buffer.algorithm_type = (
self.global_config.algorithm_type
)

# set buffer.explorer_output
if self.buffer.explorer_output is None:
self.buffer.explorer_output = self.buffer.trainer_input.experience_buffer
else:
self.buffer.explorer_output.algorithm_type = self.global_config.algorithm_type

# check trainer_input.sft_warmup_dataset
if (
@@ -440,7 +444,7 @@ def check_and_update(self) -> None: # noqa: C901
# check mode
if self.mode not in ["explore", "train", "both", "bench"]:
raise ValueError(f"Invalid mode: {self.mode}")
if self.trainer.algorithm_type == AlgorithmType.DPO and self.mode == "both":
if self.global_config.algorithm_type == AlgorithmType.DPO and self.mode == "both":
raise ValueError("DPO does not support `both` mode")

# check model path
@@ -454,21 +458,22 @@
self.explorer.engine_num * self.explorer.tensor_parallel_size
)
self.synchronizer.backend = self.explorer.backend
if self.mode == "bench" and self.synchronizer.sync_method != SyncMethod.CHECKPOINT:
if (
self.mode in ["train", "explore", "bench"]
and self.synchronizer.sync_method != SyncMethod.CHECKPOINT
):
self.synchronizer.sync_method = SyncMethod.CHECKPOINT
logger.warning(
"Bench mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
f"`{self.mode}` mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
)
if (
self.trainer.algorithm_type == AlgorithmType.DPO
self.global_config.algorithm_type == AlgorithmType.DPO
and self.synchronizer.sync_method != SyncMethod.CHECKPOINT
):
self.synchronizer.sync_method = SyncMethod.CHECKPOINT
logger.warning(
"DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
)
if self.synchronizer.sync_method == SyncMethod.NCCL and self.mode != "both":
raise ValueError("`nccl` synchronization is only supported in both mode.")

self._check_interval()

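Besides the key relocation, `check_and_update` now forces checkpoint synchronization for every non-`both` mode, and `_check_buffer` still requires an experience-buffer path for DPO in `train` mode. A sketch of a `train`-mode DPO config that would pass these checks, assuming the YAML keys mirror the dataclass fields; the path and any omitted buffer fields are placeholders:

```yaml
mode: train                        # DPO rejects the `both` mode
global_config:
  algorithm_type: dpo
buffer:
  trainer_input:
    experience_buffer:
      path: '/PATH/TO/DPO/DATASET' # required when algorithm_type is dpo
synchronizer:
  sync_method: checkpoint          # non-`both` modes are forced to checkpoint sync anyway
```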
6 changes: 3 additions & 3 deletions trinity/common/verl_config.py
@@ -310,11 +310,11 @@ def synchronize_config(self, config: Config) -> None:
self.critic.ppo_mini_batch_size = config.global_config.batch_size
self.critic.rollout_n = self.actor_rollout_ref.rollout.n

self.actor_rollout_ref.actor.algorithm_type = config.trainer.algorithm_type
if config.trainer.algorithm_type == AlgorithmType.PPO:
self.actor_rollout_ref.actor.algorithm_type = config.global_config.algorithm_type
if config.global_config.algorithm_type == AlgorithmType.PPO:
logger.info("Using GAE `adv_estimator` for PPO")
self.algorithm.adv_estimator = AdvantageEstimator.GAE.value
elif config.trainer.algorithm_type == AlgorithmType.GRPO:
elif config.global_config.algorithm_type == AlgorithmType.GRPO:
logger.info("Using GRPO `adv_estimator` for GRPO")
self.algorithm.adv_estimator = AdvantageEstimator.GRPO.value

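`synchronize_config` now chooses the veRL advantage estimator from `global_config.algorithm_type`: `ppo` maps to GAE and `grpo` to the GRPO estimator. For example, a config along these lines (paths reuse the placeholders from the examples above) would train with GAE:

```yaml
global_config:
  algorithm_type: ppo  # synchronize_config maps this to AdvantageEstimator.GAE
trainer:
  trainer_type: 'verl'
  trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
```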