
Commit 0b17681

Add trainer_strategy and save_hf_checkpoint (#412)
1 parent c615ee7 · commit 0b17681

7 files changed: +62 −17 lines


docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 13 additions & 3 deletions
@@ -436,10 +436,12 @@ Specifies the backend and behavior of the trainer.
 ```yaml
 trainer:
   name: trainer
-  trainer_type: 'verl'
-  save_interval: 100
+  trainer_type: "verl"
+  trainer_strategy: "fsdp"
   total_steps: 1000
+  save_interval: 100
   save_strategy: "unrestricted"
+  save_hf_checkpoint: "last"
   grad_clip: 1.0
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
@@ -449,13 +451,21 @@ trainer:

 - `name`: Name of the trainer. This name will be used as the Ray actor's name, so it must be unique.
 - `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
-- `save_interval`: Frequency (in steps) at which to save model checkpoints.
+- `trainer_strategy`: Strategy for the VeRL trainer. Default is `fsdp`. Options include:
+  - `fsdp`: Use PyTorch FSDP.
+  - `fsdp2`: Use PyTorch FSDP2.
+  - `megatron`: Use Megatron-LM.
 - `total_steps`: Total number of training steps.
+- `save_interval`: Frequency (in steps) at which to save model checkpoints.
 - `save_strategy`: The parallel strategy used when saving the model. Defaults to `unrestricted`. The available options are as follows:
   - `single_thread`: Only one thread across the entire system is allowed to save the model; saving tasks from different threads are executed sequentially.
   - `single_process`: Only one process across the entire system is allowed to perform saving; multiple threads within that process can handle saving tasks in parallel, while saving operations across different processes are executed sequentially.
   - `single_node`: Only one compute node across the entire system is allowed to perform saving; processes and threads within that node can work in parallel, while saving operations across different nodes are executed sequentially.
   - `unrestricted`: No restrictions on saving operations; multiple nodes, processes, or threads are allowed to save the model simultaneously.
+- `save_hf_checkpoint`: When to save the model in HuggingFace format. Default is `last`. Note that saving in HuggingFace format consumes additional time, storage space, and GPU memory, which may impact training performance or lead to out-of-memory errors. Options include:
+  - `last`: Save only the last checkpoint in HuggingFace format.
+  - `always`: Save all checkpoints in HuggingFace format.
+  - `never`: Do not save in HuggingFace format.
 - `grad_clip`: Gradient clipping for updates.
 - `use_dynamic_bsz`: Whether to use dynamic batch size.
 - `max_token_len_per_gpu`: The maximum number of tokens to be processed in forward and backward when updating the policy. Effective when `use_dynamic_bsz=true`.
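
Note (not part of this commit): once a run finishes with `save_hf_checkpoint` set to `last` or `always`, the exported checkpoint can be loaded directly with the HuggingFace `transformers` API. The `actor/huggingface/` layout below is inferred from the unit test changed in this commit; the surrounding checkpoint path is a placeholder, not a documented default.

```python
# Minimal sketch, assuming the actor/huggingface/ layout asserted in the test
# changes below; the checkpoint path is a placeholder, not a real default.
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "/path/to/checkpoints/<step_dir>/actor/huggingface"  # placeholder path
model = AutoModelForCausalLM.from_pretrained(ckpt_dir)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
```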

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 14 additions & 4 deletions
@@ -227,7 +227,7 @@ buffer:
 - `total_epochs`: Total number of training epochs.
 - `total_steps`: Total number of training steps (optional). If specified, `total_epochs` does not take effect.

-### Explorer Input
+### Explorer Input Configuration

 Defines the datasets used by the explorer for training and evaluation.

@@ -289,7 +289,7 @@ buffer:
 - `default_reward_fn_type`: The reward function used during exploration. If not specified, `buffer.default_reward_fn_type` is used.
 - `workflow_args`: A dictionary for supplementing dataset-level parameters.

-### Trainer Input
+### Trainer Input Configuration

 Defines the experience buffer and optional auxiliary datasets used by the trainer.

@@ -433,10 +433,12 @@ synchronizer:
 ```yaml
 trainer:
   name: trainer
-  trainer_type: 'verl'
-  save_interval: 100
+  trainer_type: "verl"
+  trainer_strategy: "fsdp"
   total_steps: 1000
+  save_interval: 100
   save_strategy: "unrestricted"
+  save_hf_checkpoint: "last"
   grad_clip: 1.0
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
@@ -446,13 +448,21 @@ trainer:

 - `name`: Name of the trainer. This name is used as the Ray actor's name, so it must be unique.
 - `trainer_type`: Trainer backend implementation. Currently only `verl` is supported.
+- `trainer_strategy`: Training strategy for VeRL. Default is `fsdp`. Options include:
+  - `fsdp`: Use PyTorch FSDP.
+  - `fsdp2`: Use PyTorch FSDP2.
+  - `megatron`: Use Megatron-LM.
 - `save_interval`: Frequency (in steps) at which model checkpoints are saved.
 - `total_steps`: Total number of training steps.
 - `save_strategy`: The parallel strategy used when saving the model. Default is `unrestricted`. Options include:
   - `single_thread`: Only one thread in the entire system may save the model; saving tasks from different threads run sequentially.
   - `single_process`: Only one process in the entire system may perform saving; multiple threads within that process can save in parallel, while saving across different processes runs sequentially.
   - `single_node`: Only one compute node in the entire system may perform saving; processes and threads within that node can work in parallel, while saving across different nodes runs sequentially.
   - `unrestricted`: No restrictions on saving; multiple nodes, processes, or threads may save the model simultaneously.
+- `save_hf_checkpoint`: When to save checkpoints in HuggingFace format. Default is `last`. Note that saving in HuggingFace format consumes additional time, storage space, and GPU memory, which may impact training performance or cause out-of-memory errors. Options include:
+  - `last`: Save only the final checkpoint produced by training in HuggingFace format.
+  - `always`: Save all checkpoints in HuggingFace format.
+  - `never`: Do not save checkpoints in HuggingFace format.
 - `grad_clip`: Gradient clipping threshold.
 - `use_dynamic_bsz`: Whether to use dynamic batch size.
 - `max_token_len_per_gpu`: Maximum number of tokens per GPU during training; effective when `use_dynamic_bsz=true`.

tests/template/config.yaml

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ explorer:
 trainer:
   trainer_type: verl
   save_interval: 100
+  save_hf_checkpoint: never
   grad_clip: 1.0
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384

tests/trainer/trainer_test.py

Lines changed: 10 additions & 3 deletions
@@ -90,6 +90,7 @@ def test_trainer(self):
         eval_tasksets[0].repeat_times = 4
         eval_tasksets[1].repeat_times = 4
         self.config.trainer.save_interval = 4
+        self.config.trainer.save_hf_checkpoint = "always"
         self.config.check_and_update()
         _trainer_config = self.config.trainer.trainer_config
         if self.strategy == "megatron":
@@ -134,6 +135,12 @@ def test_trainer(self):
         )
         self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_4, "actor"))) > 0)
         self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0)
+        self.assertTrue(
+            len(os.listdir(os.path.join(checkpoint_step_4, "actor", "huggingface"))) > 0
+        )
+        self.assertTrue(
+            len(os.listdir(os.path.join(checkpoint_step_8, "actor", "huggingface"))) > 0
+        )
         self.assertEqual(step_num, 8)
         ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
         # test bench mode
@@ -234,10 +241,10 @@ def test_trainer(self):
         # self.config.buffer.batch_size = 96  # TODO: used for real testing
         self.config.buffer.total_epochs = 1
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.trainer.trainer_strategy = self.fsdp_strategy
         self.config.check_and_update()
         self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
         actor_rollout_ref = self.config.trainer.trainer_config.actor_rollout_ref
-        actor_rollout_ref.actor.strategy = self.fsdp_strategy
         actor_rollout_ref.actor.optim.lr = 1e-5
         if self.fsdp_strategy == "fsdp":
             actor_rollout_ref.actor.fsdp_config.param_offload = self.offloading
@@ -679,16 +686,16 @@ def setUp(self):
         self.config.explorer.eval_interval = 4
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         self.config.trainer.save_interval = 4
+        self.config.trainer.save_hf_checkpoint = "last"
+        self.config.trainer.trainer_strategy = self.strategy
         self.config.check_and_update()

     def test_trainer(self):
         """Test the checkpoint saving."""
         _trainer_config = self.config.trainer.trainer_config
         if self.strategy == "megatron":
-            _trainer_config.actor_rollout_ref.actor.strategy = "megatron"
             _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
             _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
-            _trainer_config.critic.strategy = "megatron"
             _trainer_config.critic.megatron.tensor_model_parallel_size = 2
             _trainer_config.trainer.max_actor_ckpt_to_keep = 2
             _trainer_config.trainer.max_critic_ckpt_to_keep = 2

trinity/common/config.py

Lines changed: 13 additions & 0 deletions
@@ -683,12 +683,18 @@ class ExplorerConfig:
 class TrainerConfig:
     name: str = TRAINER_NAME
     trainer_type: str = "verl"
+    trainer_strategy: str = "fsdp"
     save_interval: int = 0
     enable_preview: bool = True  # enable rollout preview in wandb
     total_steps: Optional[
         int
     ] = None  # total training steps, training stops when reaching this step, None means no limit

+    save_hf_checkpoint: str = "last"  # whether to save checkpoint in HuggingFace format
+    # "always": save all checkpoints in HF format
+    # "never": never save checkpoint in HF format
+    # "last": only save the last checkpoint in HF format
+
     # trainer configs
     grad_clip: float = 1.0
     use_dynamic_bsz: bool = True
@@ -1217,6 +1223,8 @@ def __iter__(self):
             setattr(new_config, field_name, stage_value)
         if stage.stage_name:
             new_config.name = f"{self.name}/{stage.stage_name}"
+        # set trainer.save_hf_checkpoint to "last" to make sure next stage can load from HF checkpoint
+        new_config.trainer.save_hf_checkpoint = "last"
         new_config.stages = []
         yield new_config


@@ -1398,6 +1406,11 @@ def check_and_update(self) -> Config: # noqa: C901
                 self.trainer.max_token_len_per_gpu = math.ceil(
                     2 * self.model.max_model_len / self.trainer.ulysses_sequence_parallel_size  # type: ignore [operator]
                 )
+            if self.trainer.save_hf_checkpoint not in {"last", "always", "never"}:
+                raise ValueError(
+                    f"Invalid trainer.save_hf_checkpoint: {self.trainer.save_hf_checkpoint}, "
+                    "must be one of 'last', 'always', or 'never'."
+                )
         else:
             raise ValueError(f"Invalid trainer type: {self.trainer_type}")
         self.trainer.trainer_config.synchronize_config(self)
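
Note (not part of this commit): the new validation only runs inside `check_and_update()` on a fully built `Config`. A standalone pre-check of a YAML file can mirror it; the helper name and the use of PyYAML below are illustrative assumptions, not Trinity APIs.

```python
# Illustrative mirror of the save_hf_checkpoint validation added above; the
# function name and PyYAML usage are assumptions, not part of the Trinity codebase.
import yaml

VALID_SAVE_HF_MODES = {"last", "always", "never"}


def precheck_trainer_block(path: str) -> None:
    """Raise early if a YAML config's trainer.save_hf_checkpoint value is invalid."""
    with open(path) as f:
        trainer = (yaml.safe_load(f) or {}).get("trainer", {})
    mode = trainer.get("save_hf_checkpoint", "last")
    if mode not in VALID_SAVE_HF_MODES:
        raise ValueError(
            f"Invalid trainer.save_hf_checkpoint: {mode}, "
            "must be one of 'last', 'always', or 'never'."
        )
```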

trinity/common/verl_config.py

Lines changed: 6 additions & 5 deletions
@@ -5,7 +5,7 @@

 from omegaconf import OmegaConf

-from trinity.common.config import Config, SynchronizerConfig
+from trinity.common.config import Config, SynchronizerConfig, set_if_none
 from trinity.common.constants import EXPLORER_NAME
 from trinity.utils.log import get_logger

@@ -137,7 +137,7 @@ class ProfileConfig:

 @dataclass
 class Actor:
-    strategy: str = "fsdp"
+    strategy: Optional[str] = None
     ppo_mini_batch_size: int = 256
     ppo_micro_batch_size: Optional[int] = None
     ppo_micro_batch_size_per_gpu: int = 1
@@ -232,7 +232,7 @@ class CriticModel:

 @dataclass
 class Critic:
-    strategy: str = "fsdp"
+    strategy: Optional[str] = None
     optim: Optim = field(default_factory=Optim)
     model: CriticModel = field(default_factory=CriticModel)
     ppo_mini_batch_size: int = 0
@@ -270,7 +270,7 @@ class _RewardModel:
 @dataclass
 class RewardModel:
     enable: bool = False
-    strategy: str = "fsdp"
+    strategy: Optional[str] = None
     model: _RewardModel = field(default_factory=_RewardModel)
     micro_batch_size_per_gpu: int = 1
     max_length: Optional[int] = None
@@ -416,6 +416,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
         self.critic.ray_namespace = config.synchronizer.ray_namespace

         # Actor / Rollout Config
+        set_if_none(self.actor_rollout_ref.actor, "strategy", config.trainer.trainer_strategy)
         self.actor_rollout_ref.model.path = config.model.model_path
         self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template
         self.actor_rollout_ref.model.rope_scaling = config.model.rope_scaling
@@ -488,7 +489,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
         )

         # Critic config
-        self.critic.strategy = self.actor_rollout_ref.actor.strategy
+        set_if_none(self.critic, "strategy", config.trainer.trainer_strategy)
         self.critic.model.path = config.model.critic_model_path
         self.critic.model.tokenizer_path = config.model.critic_model_path
         self.critic.ppo_mini_batch_size = config.buffer.train_batch_size
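
Note: `set_if_none` is imported from `trinity.common.config`, but its body is not part of this diff. The sketch below captures the behavior implied by its call sites: the fallback from `trainer_strategy` is applied only when the per-component `strategy` field is still `None`, so an explicitly configured VeRL strategy wins.

```python
# Sketch of the behavior implied by the call sites above; the real implementation
# lives in trinity.common.config and may differ in details.
from typing import Any


def set_if_none(obj: Any, attr: str, value: Any) -> None:
    """Assign `value` to `obj.attr` only when the attribute is currently None."""
    if getattr(obj, attr) is None:
        setattr(obj, attr, value)
```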

trinity/trainer/trainer.py

Lines changed: 5 additions & 2 deletions
@@ -63,6 +63,7 @@ def __init__(self, config: Config) -> None:
         self.last_sync_step = None
         self.last_sync_time = None
         self.total_steps = config.trainer.total_steps or float("inf")
+        self.save_hf_checkpoint = config.trainer.save_hf_checkpoint

     async def prepare(self) -> None:
         """Prepare the trainer."""
@@ -90,7 +91,9 @@ async def train(self) -> str:
                 if await self.need_sync():
                     metrics.update(await self.sync_weight())
                 if self.need_save():
-                    metrics.update(self.save_checkpoint())
+                    metrics.update(
+                        self.save_checkpoint(save_as_hf=self.save_hf_checkpoint == "always")
+                    )
                 if self.config.trainer.enable_preview:
                     self._log_experiences(repr_samples)
                 self.monitor.log(metrics, self.train_step_num)
@@ -101,7 +104,7 @@ async def train(self) -> str:
                 self.logger.error(f"Error in Trainer:\n{traceback.format_exc()}")
                 break

-        self.save_checkpoint(block_until_saved=True, save_as_hf=True)
+        self.save_checkpoint(block_until_saved=True, save_as_hf=self.save_hf_checkpoint != "never")
         await self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)
         self.logger.info("--------------------\n> Trainer finished.\n--------------------")
         return self.config.trainer.name
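
Taken together, the two changed call sites implement a simple policy, summarized by the hypothetical helper below (it does not exist in `trinity/trainer/trainer.py`): periodic saves export a HuggingFace checkpoint only when `save_hf_checkpoint` is `"always"`, while the final blocking save exports one for both `"always"` and `"last"`.

```python
# Hypothetical summary of the policy implemented by the two call sites above;
# this helper is not part of the codebase.
def should_save_as_hf(save_hf_checkpoint: str, is_final_save: bool) -> bool:
    if is_final_save:
        return save_hf_checkpoint != "never"  # "last" and "always" export HF format
    return save_hf_checkpoint == "always"  # periodic checkpoints
```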
