Merged
Changes from 4 commits
16 changes: 13 additions & 3 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -436,10 +436,12 @@ Specifies the backend and behavior of the trainer.
```yaml
trainer:
name: trainer
trainer_type: 'verl'
save_interval: 100
trainer_type: "verl"
trainer_strategy: "fsdp"
total_steps: 1000
save_interval: 100
save_strategy: "unrestricted"
save_hf_checkpoint: "last"
grad_clip: 1.0
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
@@ -449,13 +451,21 @@ trainer:

- `name`: Name of the trainer. This name will be used as the Ray actor's name, so it must be unique.
- `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
- `save_interval`: Frequency (in steps) at which to save model checkpoints.
- `trainer_strategy`: Training strategy for the VeRL trainer. Default is `fsdp`. Options include:
- `fsdp`: Use PyTorch FSDP.
- `fsdp2`: Use PyTorch FSDP2.
- `megatron`: Use Megatron-LM.
- `total_steps`: Total number of training steps.
- `save_interval`: Frequency (in steps) at which to save model checkpoints.
- `save_strategy`: The parallel strategy used when saving the model. Defaults to `unrestricted`. The available options are as follows:
- `single_thread`: Only one thread across the entire system is allowed to save the model; saving tasks from different threads are executed sequentially.
- `single_process`: Only one process across the entire system is allowed to perform saving; multiple threads within that process can handle saving tasks in parallel, while saving operations across different processes are executed sequentially.
- `single_node`: Only one compute node across the entire system is allowed to perform saving; processes and threads within that node can work in parallel, while saving operations across different nodes are executed sequentially.
- `unrestricted`: No restrictions on saving operations; multiple nodes, processes, or threads are allowed to save the model simultaneously.
- `save_hf_checkpoint`: When to save checkpoints in HuggingFace format. Default is `last`. Note that saving in HuggingFace format consumes additional time, storage space, and GPU memory, which may impact training performance or lead to out-of-memory errors; a short illustrative sketch follows this list. Options include:
- `last`: Save only the last checkpoint in HuggingFace format.
- `always`: Save all checkpoints in HuggingFace format.
- `never`: Do not save in HuggingFace format.
- `grad_clip`: Gradient clipping threshold for model updates.
- `use_dynamic_bsz`: Whether to use dynamic batch size.
- `max_token_len_per_gpu`: The maximum number of tokens to be processed in forward and backward when updating the policy. Effective when `use_dynamic_bsz=true`.
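
To make the three `save_hf_checkpoint` values concrete, here is a minimal, purely illustrative sketch; `export_hf` is a hypothetical helper (not a Trinity-RFT API) that simply restates the behavior documented above and implemented in `trinity/trainer/trainer.py` later in this diff.

```python
# Illustrative sketch only: which checkpoints also get a HuggingFace-format
# export for each save_hf_checkpoint value. export_hf is a hypothetical helper,
# not a Trinity-RFT API; it only mirrors the behavior documented above.
def export_hf(save_hf_checkpoint: str, is_final_checkpoint: bool) -> bool:
    if save_hf_checkpoint == "always":
        return True  # every saved checkpoint is also exported in HF format
    if save_hf_checkpoint == "last":
        return is_final_checkpoint  # only the checkpoint written when training ends
    return False  # "never": no HF export

# With save_interval=100 and total_steps=1000, "always" exports HF weights at
# every 100-step save, while "last" exports only once, when training finishes.
assert export_hf("last", is_final_checkpoint=False) is False
assert export_hf("last", is_final_checkpoint=True) is True
```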
18 changes: 14 additions & 4 deletions docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -227,7 +227,7 @@ buffer:
- `total_epochs`: Total number of training epochs.
- `total_steps`: Total number of training steps (optional). If specified, `total_epochs` is ignored.

### Explorer Input
### Explorer Input Configuration

Defines the datasets used by the explorer for training and evaluation.

@@ -289,7 +289,7 @@ buffer:
- `default_reward_fn_type`: Reward function used during exploration. If not specified, `buffer.default_reward_fn_type` is used.
- `workflow_args`: A dictionary of arguments that supplements the dataset-level parameters.

### Trainer Input
### Trainer Input Configuration

Defines the experience buffer and optional auxiliary datasets used by the trainer.

@@ -433,10 +433,12 @@ synchronizer:
```yaml
trainer:
name: trainer
trainer_type: 'verl'
save_interval: 100
trainer_type: "verl"
trainer_strategy: "fsdp"
total_steps: 1000
save_interval: 100
save_strategy: "unrestricted"
save_hf_checkpoint: "last"
grad_clip: 1.0
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
@@ -446,13 +448,21 @@ trainer:

- `name`: Name of the trainer. This name will be used as the Ray actor's name, so it must be unique.
- `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
- `trainer_strategy`: Training strategy for the VeRL trainer. Default is `fsdp`. Options include:
- `fsdp`: Use PyTorch FSDP.
- `fsdp2`: Use PyTorch FSDP2.
- `megatron`: Use Megatron-LM.
- `save_interval`: Frequency (in steps) at which to save model checkpoints.
- `total_steps`: Total number of training steps.
- `save_strategy`: The parallel strategy used when saving the model (a short illustrative sketch follows this list). Defaults to `unrestricted`. The available options are as follows:
- `single_thread`: Only one thread across the entire system is allowed to save the model; saving tasks from different threads are executed sequentially.
- `single_process`: Only one process across the entire system is allowed to perform saving; multiple threads within that process can handle saving tasks in parallel, while saving operations across different processes are executed sequentially.
- `single_node`: Only one compute node across the entire system is allowed to perform saving; processes and threads within that node can work in parallel, while saving operations across different nodes are executed sequentially.
- `unrestricted`: No restrictions on saving; multiple nodes, processes, or threads are allowed to save the model simultaneously.
- `save_hf_checkpoint`: When to save checkpoints in HuggingFace format. Default is `last`. Note that saving in HuggingFace format consumes additional time, storage space, and GPU memory, which may impact training performance or lead to out-of-memory errors. Options include:
- `last`: Save only the last checkpoint produced by training in HuggingFace format.
- `always`: Save all checkpoints in HuggingFace format.
- `never`: Do not save checkpoints in HuggingFace format.
- `grad_clip`: Gradient clipping threshold.
- `use_dynamic_bsz`: Whether to use dynamic batch size.
- `max_token_len_per_gpu`: Maximum token length per GPU during training; effective when `use_dynamic_bsz=true`.
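
To make the four `save_strategy` levels concrete, here is a small, purely illustrative sketch of the concurrency each level permits; `SaveScope` and `writers_may_overlap` are hypothetical names, not Trinity-RFT APIs, and only restate the options listed above.

```python
# Illustrative sketch only: whether a second checkpoint writer may run
# concurrently with an existing one under each save_strategy value.
# SaveScope and writers_may_overlap are hypothetical, not Trinity-RFT APIs.
from enum import Enum

class SaveScope(Enum):
    SINGLE_THREAD = "single_thread"    # all saves are fully serialized
    SINGLE_PROCESS = "single_process"  # threads of one process may save in parallel
    SINGLE_NODE = "single_node"        # processes/threads of one node may save in parallel
    UNRESTRICTED = "unrestricted"      # any node, process, or thread may save concurrently

def writers_may_overlap(scope: SaveScope, same_process: bool, same_node: bool) -> bool:
    """Return True if two save operations are allowed to run at the same time."""
    if scope is SaveScope.UNRESTRICTED:
        return True
    if scope is SaveScope.SINGLE_NODE:
        return same_node
    if scope is SaveScope.SINGLE_PROCESS:
        return same_process
    return False  # SINGLE_THREAD: saves from different threads run sequentially

# Example: under "single_node", two writers on the same node may overlap,
# but writers on different nodes are serialized.
assert writers_may_overlap(SaveScope.SINGLE_NODE, same_process=False, same_node=True)
assert not writers_may_overlap(SaveScope.SINGLE_NODE, same_process=False, same_node=False)
```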
1 change: 1 addition & 0 deletions tests/template/config.yaml
@@ -45,6 +45,7 @@ explorer:
trainer:
trainer_type: verl
save_interval: 100
save_hf_checkpoint: never
grad_clip: 1.0
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
13 changes: 10 additions & 3 deletions tests/trainer/trainer_test.py
@@ -90,6 +90,7 @@ def test_trainer(self):
eval_tasksets[0].repeat_times = 4
eval_tasksets[1].repeat_times = 4
self.config.trainer.save_interval = 4
self.config.trainer.save_hf_checkpoint = "always"
self.config.check_and_update()
_trainer_config = self.config.trainer.trainer_config
if self.strategy == "megatron":
@@ -134,6 +135,12 @@ def test_trainer(self):
)
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_4, "actor"))) > 0)
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0)
self.assertTrue(
len(os.listdir(os.path.join(checkpoint_step_4, "actor", "huggingface"))) > 0
)
self.assertTrue(
len(os.listdir(os.path.join(checkpoint_step_8, "actor", "huggingface"))) > 0
)
self.assertEqual(step_num, 8)
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
# test bench mode
@@ -234,10 +241,10 @@ def test_trainer(self):
# self.config.buffer.batch_size = 96 # TODO: used for real testing
self.config.buffer.total_epochs = 1
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.trainer.trainer_strategy = self.fsdp_strategy
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
actor_rollout_ref = self.config.trainer.trainer_config.actor_rollout_ref
actor_rollout_ref.actor.strategy = self.fsdp_strategy
actor_rollout_ref.actor.optim.lr = 1e-5
if self.fsdp_strategy == "fsdp":
actor_rollout_ref.actor.fsdp_config.param_offload = self.offloading
@@ -679,16 +686,16 @@ def setUp(self):
self.config.explorer.eval_interval = 4
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.trainer.save_interval = 4
self.config.trainer.save_hf_checkpoint = "last"
self.config.trainer.trainer_strategy = self.strategy
self.config.check_and_update()

def test_trainer(self):
"""Test the checkpoint saving."""
_trainer_config = self.config.trainer.trainer_config
if self.strategy == "megatron":
_trainer_config.actor_rollout_ref.actor.strategy = "megatron"
_trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
_trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
_trainer_config.critic.strategy = "megatron"
_trainer_config.critic.megatron.tensor_model_parallel_size = 2
_trainer_config.trainer.max_actor_ckpt_to_keep = 2
_trainer_config.trainer.max_critic_ckpt_to_keep = 2
13 changes: 13 additions & 0 deletions trinity/common/config.py
@@ -683,12 +683,18 @@ class ExplorerConfig:
class TrainerConfig:
name: str = TRAINER_NAME
trainer_type: str = "verl"
trainer_strategy: str = "fsdp"
save_interval: int = 0
enable_preview: bool = True # enable rollout preview in wandb
total_steps: Optional[
int
] = None # total training steps, training stops when reaching this step, None means no limit

save_hf_checkpoint: str = "last" # when to save checkpoints in HuggingFace format
# "always": save all checkpoints in HF format
# "never": never save checkpoint in HF format
# "last": only save the last checkpoint in HF format

# trainer configs
grad_clip: float = 1.0
use_dynamic_bsz: bool = True
@@ -1217,6 +1223,8 @@ def __iter__(self):
setattr(new_config, field_name, stage_value)
if stage.stage_name:
new_config.name = f"{self.name}/{stage.stage_name}"
# set trainer.save_hf_checkpoint to "last" so that the next stage can load from the HF checkpoint
new_config.trainer.save_hf_checkpoint = "last"
new_config.stages = []
yield new_config

@@ -1398,6 +1406,11 @@ def check_and_update(self) -> Config: # noqa: C901
self.trainer.max_token_len_per_gpu = math.ceil(
2 * self.model.max_model_len / self.trainer.ulysses_sequence_parallel_size # type: ignore [operator]
)
if self.trainer.save_hf_checkpoint not in {"last", "always", "never"}:
raise ValueError(
f"Invalid trainer.save_hf_checkpoint: {self.trainer.save_hf_checkpoint}, "
"must be one of 'last', 'always', or 'never'."
)
else:
raise ValueError(f"Invalid trainer type: {self.trainer_type}")
self.trainer.trainer_config.synchronize_config(self)
11 changes: 7 additions & 4 deletions trinity/common/verl_config.py
@@ -137,7 +137,7 @@ class ProfileConfig:

@dataclass
class Actor:
strategy: str = "fsdp"
strategy: Optional[str] = None
ppo_mini_batch_size: int = 256
ppo_micro_batch_size: Optional[int] = None
ppo_micro_batch_size_per_gpu: int = 1
@@ -232,7 +232,7 @@ class CriticModel:

@dataclass
class Critic:
strategy: str = "fsdp"
strategy: Optional[str] = None
optim: Optim = field(default_factory=Optim)
model: CriticModel = field(default_factory=CriticModel)
ppo_mini_batch_size: int = 0
@@ -270,7 +270,7 @@ class _RewardModel:
@dataclass
class RewardModel:
enable: bool = False
strategy: str = "fsdp"
strategy: Optional[str] = None
model: _RewardModel = field(default_factory=_RewardModel)
micro_batch_size_per_gpu: int = 1
max_length: Optional[int] = None
@@ -416,6 +416,8 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
self.critic.ray_namespace = config.synchronizer.ray_namespace

# Actor / Rollout Config
if self.actor_rollout_ref.actor.strategy is None:
self.actor_rollout_ref.actor.strategy = config.trainer.trainer_strategy
self.actor_rollout_ref.model.path = config.model.model_path
self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template
self.actor_rollout_ref.model.rope_scaling = config.model.rope_scaling
@@ -488,7 +490,8 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
)

# Critic config
self.critic.strategy = self.actor_rollout_ref.actor.strategy
if self.critic.strategy is None:
self.critic.strategy = config.trainer.trainer_strategy
self.critic.model.path = config.model.critic_model_path
self.critic.model.tokenizer_path = config.model.critic_model_path
self.critic.ppo_mini_batch_size = config.buffer.train_batch_size
7 changes: 5 additions & 2 deletions trinity/trainer/trainer.py
@@ -63,6 +63,7 @@ def __init__(self, config: Config) -> None:
self.last_sync_step = None
self.last_sync_time = None
self.total_steps = config.trainer.total_steps or float("inf")
self.save_hf_checkpoint = config.trainer.save_hf_checkpoint

async def prepare(self) -> None:
"""Prepare the trainer."""
@@ -90,7 +91,9 @@ async def train(self) -> str:
if await self.need_sync():
metrics.update(await self.sync_weight())
if self.need_save():
metrics.update(self.save_checkpoint())
metrics.update(
self.save_checkpoint(save_as_hf=self.save_hf_checkpoint == "always")
)
if self.config.trainer.enable_preview:
self._log_experiences(repr_samples)
self.monitor.log(metrics, self.train_step_num)
@@ -101,7 +104,7 @@ async def train(self) -> str:
self.logger.error(f"Error in Trainer:\n{traceback.format_exc()}")
break

self.save_checkpoint(block_until_saved=True, save_as_hf=True)
self.save_checkpoint(block_until_saved=True, save_as_hf=self.save_hf_checkpoint != "never")
await self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)
self.logger.info("--------------------\n> Trainer finished.\n--------------------")
return self.config.trainer.name