From 05810bd2f971c69b5c029805427f7bbf382668a4 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 26 Nov 2025 20:06:52 +0800
Subject: [PATCH 1/5] add trainer_strategy and save_hf_checkpoint

---
 .../source/tutorial/trinity_configs.md    | 16 +++++++++++++---
 .../source_zh/tutorial/trinity_configs.md | 16 +++++++++++++---
 tests/template/config.yaml                |  1 +
 tests/trainer/trainer_test.py             | 13 ++++++++++---
 trinity/common/config.py                  |  6 ++++++
 trinity/common/verl_config.py             | 11 +++++++----
 trinity/trainer/trainer.py                |  7 +++++--
 7 files changed, 55 insertions(+), 15 deletions(-)

diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 9d4aba3df0..c07fe07921 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -436,10 +436,12 @@ Specifies the backend and behavior of the trainer.
 ```yaml
 trainer:
   name: trainer
-  trainer_type: 'verl'
-  save_interval: 100
+  trainer_type: "verl"
+  trainer_strategy: "fsdp"
   total_steps: 1000
+  save_interval: 100
   save_strategy: "unrestricted"
+  save_hf_checkpoint: "last"
   grad_clip: 1.0
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
@@ -449,13 +451,21 @@
 
 - `name`: Name of the trainer. This name will be used as the Ray actor's name, so it must be unique.
 - `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
-- `save_interval`: Frequency (in steps) at which to save model checkpoints.
+- `trainer_strategy`: Training strategy used by the VeRL trainer. Default is `fsdp`. Options include:
+  - `fsdp`: Use PyTorch FSDP.
+  - `fsdp2`: Use PyTorch FSDP2.
+  - `megatron`: Use Megatron-LM.
 - `total_steps`: Total number of training steps.
+- `save_interval`: Frequency (in steps) at which to save model checkpoints.
 - `save_strategy`: The parallel strategy used when saving the model. Defaults to `unrestricted`. The available options are as follows:
   - `single_thread`: Only one thread across the entire system is allowed to save the model; saving tasks from different threads are executed sequentially.
   - `single_process`: Only one process across the entire system is allowed to perform saving; multiple threads within that process can handle saving tasks in parallel, while saving operations across different processes are executed sequentially.
   - `single_node`: Only one compute node across the entire system is allowed to perform saving; processes and threads within that node can work in parallel, while saving operations across different nodes are executed sequentially.
   - `unrestricted`: No restrictions on saving operations; multiple nodes, processes, or threads are allowed to save the model simultaneously.
+- `save_hf_checkpoint`: When to save checkpoints in HuggingFace format in addition to the native checkpoint format. Default is `last`. Note that saving in HuggingFace format consumes additional time, storage space, and GPU memory, which may impact training performance or lead to out-of-memory errors. Options include:
+  - `last`: Save only the last checkpoint in HuggingFace format.
+  - `always`: Save all checkpoints in HuggingFace format.
+  - `never`: Do not save in HuggingFace format.
 - `grad_clip`: Gradient clipping for updates.
 - `use_dynamic_bsz`: Whether to use dynamic batch size.
 - `max_token_len_per_gpu`: The maximum number of tokens to be processed in forward and backward when updating the policy. Effective when `use_dynamic_bsz=true`.
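The two `save_checkpoint` call sites changed later in this series (see `trinity/trainer/trainer.py` below) reduce `save_hf_checkpoint` to two booleans. A minimal sketch of that mapping, assuming only the documented semantics above; the helper name `save_as_hf_flags` is illustrative and not part of the patch:

```python
def save_as_hf_flags(save_hf_checkpoint: str) -> tuple[bool, bool]:
    """Map trainer.save_hf_checkpoint to (periodic_save_as_hf, final_save_as_hf).

    Periodic checkpoints are exported to HuggingFace format only for "always";
    the final checkpoint is exported for both "always" and "last".
    """
    periodic = save_hf_checkpoint == "always"
    final = save_hf_checkpoint != "never"
    return periodic, final
```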
diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
index 1d45d43b56..b6922c434d 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -227,7 +227,7 @@ buffer:
 - `total_epochs`: 总训练轮数。
 - `total_steps`: 总训练步数(可选)。若指定,则 `total_epochs` 不生效。
 
-### Explorer 输入
+### Explorer 输入配置
 
 定义 explorer 用于训练和评估的数据集。
 
@@ -289,7 +289,7 @@ buffer:
 - `default_reward_fn_type`: 探索过程中使用的奖励函数。若未指定,则使用 `buffer.default_reward_fn_type`。
 - `workflow_args`: 用于补充数据集级别参数的字典。
 
-### Trainer 输入
+### Trainer 输入配置
 
 定义 trainer 使用的 experience buffer 和可选的辅助数据集。
 
@@ -434,9 +434,11 @@ synchronizer:
 trainer:
   name: trainer
   trainer_type: 'verl'
-  save_interval: 100
+  trainer_strategy: "fsdp"
   total_steps: 1000
+  save_interval: 100
   save_strategy: "unrestricted"
+  save_hf_checkpoint: "last"
   grad_clip: 1.0
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
@@ -446,6 +448,10 @@ trainer:
 
 - `name`: trainer 的名称。该名称将用作 Ray actor 的名称,因此必须唯一。
 - `trainer_type`: trainer 后端实现。目前仅支持 `verl`。
+- `trainer_strategy`: VeRL 的训练策略。默认值为 `fsdp`。可选值如下:
+  - `fsdp`: 使用 PyTorch FSDP。
+  - `fsdp2`: 使用 PyTorch FSDP2。
+  - `megatron`: 使用 Megatron-LM。
 - `save_interval`: 保存模型检查点的频率(步)。
 - `total_steps`: 总训练步数。
 - `save_strategy`: 模型保存时的并行策略。默认值为`unrestricted`。可选值如下:
@@ -453,6 +459,10 @@ trainer:
   - `single_process`:整个系统中,仅允许一个进程执行保存,该进程内的多个线程可以并行处理保存任务,不同进程之间串行执行。
   - `single_node`:整个系统中,仅允许一个计算节点执行保存,该节点内的进程和线程可并行工作,不同节点的保存串行执行。
   - `unrestricted`:不限制保存操作,允许多个节点、进程或线程同时保存模型。
+- `save_hf_checkpoint`: 指定保存 HuggingFace 格式检查点的时机,默认为 "last"。可选值:
+  - `last`: 仅训练产生的最后一个检查点保存为 HuggingFace 格式。
+  - `always`: 所有检查点均保存为 HuggingFace 格式。
+  - `never`: 不保存 HuggingFace 格式检查点。
 - `grad_clip`: 梯度裁剪阈值。
 - `use_dynamic_bsz`: 是否使用动态批量大小。
 - `max_token_len_per_gpu`: 训练过程中,每个 GPU 最大 token 长度; 当 `use_dynamic_bsz=true` 时生效。
diff --git a/tests/template/config.yaml b/tests/template/config.yaml
index 425b327a98..745dcc50a1 100644
--- a/tests/template/config.yaml
+++ b/tests/template/config.yaml
@@ -45,6 +45,7 @@ explorer:
 trainer:
   trainer_type: verl
   save_interval: 100
+  save_hf_checkpoint: never
   grad_clip: 1.0
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 397c468a8b..c22780d811 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -90,6 +90,7 @@ def test_trainer(self):
         eval_tasksets[0].repeat_times = 4
         eval_tasksets[1].repeat_times = 4
         self.config.trainer.save_interval = 4
+        self.config.trainer.save_hf_checkpoint = "always"
         self.config.check_and_update()
         _trainer_config = self.config.trainer.trainer_config
         if self.strategy == "megatron":
@@ -134,6 +135,12 @@ def test_trainer(self):
         )
         self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_4, "actor"))) > 0)
         self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0)
+        self.assertTrue(
+            len(os.listdir(os.path.join(checkpoint_step_4, "actor", "huggingface"))) > 0
+        )
+        self.assertTrue(
+            len(os.listdir(os.path.join(checkpoint_step_8, "actor", "huggingface"))) > 0
+        )
         self.assertEqual(step_num, 8)
         ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
         # test bench mode
@@ -234,10 +241,10 @@ def test_trainer(self):
         # self.config.buffer.batch_size = 96  # TODO: used for real testing
         self.config.buffer.total_epochs = 1
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.trainer.trainer_strategy = self.fsdp_strategy
         self.config.check_and_update()
         self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
         actor_rollout_ref = self.config.trainer.trainer_config.actor_rollout_ref
-        actor_rollout_ref.actor.strategy = self.fsdp_strategy
         actor_rollout_ref.actor.optim.lr = 1e-5
         if self.fsdp_strategy == "fsdp":
             actor_rollout_ref.actor.fsdp_config.param_offload = self.offloading
@@ -679,16 +686,16 @@ def setUp(self):
         self.config.explorer.eval_interval = 4
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         self.config.trainer.save_interval = 4
+        self.config.trainer.save_hf_checkpoint = "last"
+        self.config.trainer.trainer_strategy = self.strategy
         self.config.check_and_update()
 
     def test_trainer(self):
         """Test the checkpoint saving."""
         _trainer_config = self.config.trainer.trainer_config
         if self.strategy == "megatron":
-            _trainer_config.actor_rollout_ref.actor.strategy = "megatron"
             _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
             _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
-            _trainer_config.critic.strategy = "megatron"
             _trainer_config.critic.megatron.tensor_model_parallel_size = 2
         _trainer_config.trainer.max_actor_ckpt_to_keep = 2
         _trainer_config.trainer.max_critic_ckpt_to_keep = 2
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 21a257cca6..a8afedee37 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -683,12 +683,18 @@ class ExplorerConfig:
 class TrainerConfig:
     name: str = TRAINER_NAME
     trainer_type: str = "verl"
+    trainer_strategy: str = "fsdp"
     save_interval: int = 0
     enable_preview: bool = True  # enable rollout preview in wandb
     total_steps: Optional[
         int
     ] = None  # total training steps, training stops when reaching this step, None means no limit
+    save_hf_checkpoint: str = "last"  # when to save checkpoints in HuggingFace format
+    # "always": save all checkpoints in HF format
+    # "never": never save checkpoints in HF format
+    # "last": only save the last checkpoint in HF format
+
     # trainer configs
     grad_clip: float = 1.0
     use_dynamic_bsz: bool = True
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index eddcbba727..055ee32c01 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -137,7 +137,7 @@ class ProfileConfig:
 
 @dataclass
 class Actor:
-    strategy: str = "fsdp"
+    strategy: Optional[str] = None
     ppo_mini_batch_size: int = 256
     ppo_micro_batch_size: Optional[int] = None
     ppo_micro_batch_size_per_gpu: int = 1
@@ -232,7 +232,7 @@ class CriticModel:
 
 @dataclass
 class Critic:
-    strategy: str = "fsdp"
+    strategy: Optional[str] = None
     optim: Optim = field(default_factory=Optim)
     model: CriticModel = field(default_factory=CriticModel)
     ppo_mini_batch_size: int = 0
@@ -270,7 +270,7 @@ class _RewardModel:
 @dataclass
 class RewardModel:
     enable: bool = False
-    strategy: str = "fsdp"
+    strategy: Optional[str] = None
     model: _RewardModel = field(default_factory=_RewardModel)
     micro_batch_size_per_gpu: int = 1
     max_length: Optional[int] = None
@@ -416,6 +416,8 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         self.critic.ray_namespace = config.synchronizer.ray_namespace
 
         # Actor / Rollout Config
+        if self.actor_rollout_ref.actor.strategy is None:
+            self.actor_rollout_ref.actor.strategy = config.trainer.trainer_strategy
         self.actor_rollout_ref.model.path = config.model.model_path
         self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template
         self.actor_rollout_ref.model.rope_scaling = config.model.rope_scaling
@@ -488,7 +490,8 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         )
 
         # Critic config
-        self.critic.strategy = self.actor_rollout_ref.actor.strategy
+        if self.critic.strategy is None:
+            self.critic.strategy = config.trainer.trainer_strategy
         self.critic.model.path = config.model.critic_model_path
         self.critic.model.tokenizer_path = config.model.critic_model_path
         self.critic.ppo_mini_batch_size = config.buffer.train_batch_size
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 8b93c32dbf..c42fccfb26 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -63,6 +63,7 @@ def __init__(self, config: Config) -> None:
         self.last_sync_step = None
         self.last_sync_time = None
         self.total_steps = config.trainer.total_steps or float("inf")
+        self.save_hf_checkpoint = config.trainer.save_hf_checkpoint
 
     async def prepare(self) -> None:
         """Prepare the trainer."""
@@ -90,7 +91,9 @@ async def train(self) -> str:
                 if await self.need_sync():
                     metrics.update(await self.sync_weight())
                 if self.need_save():
-                    metrics.update(self.save_checkpoint())
+                    metrics.update(
+                        self.save_checkpoint(save_as_hf=self.save_hf_checkpoint == "always")
+                    )
                 if self.config.trainer.enable_preview:
                     self._log_experiences(repr_samples)
                 self.monitor.log(metrics, self.train_step_num)
@@ -101,7 +104,7 @@ async def train(self) -> str:
                 self.logger.error(f"Error in Trainer:\n{traceback.format_exc()}")
                 break
 
-        self.save_checkpoint(block_until_saved=True, save_as_hf=True)
+        self.save_checkpoint(block_until_saved=True, save_as_hf=self.save_hf_checkpoint != "never")
         await self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)
         self.logger.info("--------------------\n> Trainer finished.\n--------------------")
         return self.config.trainer.name

From af80edc35c97ac5ee737b940e2d46abc59bb368c Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 26 Nov 2025 20:11:21 +0800
Subject: [PATCH 2/5] fix comments

---
 docs/sphinx_doc/source_zh/tutorial/trinity_configs.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
index b6922c434d..7e7aec64b8 100644
--- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -433,7 +433,7 @@ synchronizer:
 ```yaml
 trainer:
   name: trainer
-  trainer_type: 'verl'
+  trainer_type: "verl"
   trainer_strategy: "fsdp"
   total_steps: 1000
   save_interval: 100
@@ -459,7 +459,7 @@ trainer:
   - `single_process`:整个系统中,仅允许一个进程执行保存,该进程内的多个线程可以并行处理保存任务,不同进程之间串行执行。
   - `single_node`:整个系统中,仅允许一个计算节点执行保存,该节点内的进程和线程可并行工作,不同节点的保存串行执行。
   - `unrestricted`:不限制保存操作,允许多个节点、进程或线程同时保存模型。
-- `save_hf_checkpoint`: 指定保存 HuggingFace 格式检查点的时机,默认为 "last"。可选值:
+- `save_hf_checkpoint`: 指定保存 HuggingFace 格式检查点的时机,默认为 "last"。注意保存为 HuggingFace 格式会消耗额外的时间、存储空间和显存,可能影响训练性能或导致显存不足错误。可选值:
   - `last`: 仅训练产生的最后一个检查点保存为 HuggingFace 格式。
   - `always`: 所有检查点均保存为 HuggingFace 格式。
   - `never`: 不保存 HuggingFace 格式检查点。

From abea03570f456847b3104af94ed7de3ad8a72004 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 26 Nov 2025 21:48:46 +0800
Subject: [PATCH 3/5] fix multi-stage training

---
 trinity/common/config.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/trinity/common/config.py b/trinity/common/config.py
index a8afedee37..1b9e392638 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -1223,6 +1223,8 @@ def __iter__(self):
                 setattr(new_config, field_name, stage_value)
             if stage.stage_name:
                 new_config.name = f"{self.name}/{stage.stage_name}"
+            # set trainer.save_hf_checkpoint to "last" to make sure the next stage can load from an HF checkpoint
+            new_config.trainer.save_hf_checkpoint = "last"
             new_config.stages = []
             yield new_config

From 04e596d805f8b2157d54d83254a519c8b09f0d04 Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 27 Nov 2025 10:09:07 +0800
Subject: [PATCH 4/5] add save_hf_checkpoint check

---
 trinity/common/config.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/trinity/common/config.py b/trinity/common/config.py
index 1b9e392638..904aac0581 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -1406,6 +1406,11 @@ def check_and_update(self) -> Config:  # noqa: C901
                 self.trainer.max_token_len_per_gpu = math.ceil(
                     2 * self.model.max_model_len / self.trainer.ulysses_sequence_parallel_size  # type: ignore [operator]
                 )
+            if self.trainer.save_hf_checkpoint not in {"last", "always", "never"}:
+                raise ValueError(
+                    f"Invalid trainer.save_hf_checkpoint: {self.trainer.save_hf_checkpoint}, "
+                    "must be one of 'last', 'always', or 'never'."
+                )
         else:
             raise ValueError(f"Invalid trainer type: {self.trainer_type}")
         self.trainer.trainer_config.synchronize_config(self)

From 1551e084fef3f3c4f239896ff8253124c6b985e9 Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 27 Nov 2025 14:26:37 +0800
Subject: [PATCH 5/5] fix comments

---
 trinity/common/verl_config.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 055ee32c01..b157c343e7 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -5,7 +5,7 @@
 
 from omegaconf import OmegaConf
 
-from trinity.common.config import Config, SynchronizerConfig
+from trinity.common.config import Config, SynchronizerConfig, set_if_none
 from trinity.common.constants import EXPLORER_NAME
 from trinity.utils.log import get_logger
 
@@ -416,8 +416,7 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         self.critic.ray_namespace = config.synchronizer.ray_namespace
 
         # Actor / Rollout Config
-        if self.actor_rollout_ref.actor.strategy is None:
-            self.actor_rollout_ref.actor.strategy = config.trainer.trainer_strategy
+        set_if_none(self.actor_rollout_ref.actor, "strategy", config.trainer.trainer_strategy)
         self.actor_rollout_ref.model.path = config.model.model_path
         self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template
         self.actor_rollout_ref.model.rope_scaling = config.model.rope_scaling
@@ -489,8 +489,7 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         )
 
         # Critic config
-        if self.critic.strategy is None:
-            self.critic.strategy = config.trainer.trainer_strategy
+        set_if_none(self.critic, "strategy", config.trainer.trainer_strategy)
         self.critic.model.path = config.model.critic_model_path
         self.critic.model.tokenizer_path = config.model.critic_model_path
         self.critic.ppo_mini_batch_size = config.buffer.train_batch_size
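PATCH 5 imports `set_if_none` from `trinity.common.config`, but its definition is not part of this series. A minimal sketch of what such a helper could look like, assuming it simply fills in unset fields; the actual implementation in the repository may differ:

```python
from typing import Any


def set_if_none(obj: Any, attr: str, value: Any) -> None:
    """Set ``obj.attr`` to ``value`` only if the attribute is currently None.

    Used here so that an explicitly configured actor/critic strategy is kept,
    while ``trainer.trainer_strategy`` fills in the default otherwise.
    """
    if getattr(obj, attr) is None:
        setattr(obj, attr, value)
```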