2 changes: 2 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -475,6 +475,7 @@ trainer:
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
max_checkpoints_to_keep: 5
trainer_config: null
```

@@ -499,6 +500,7 @@ trainer:
- `use_dynamic_bsz`: Whether to use dynamic batch size.
- `max_token_len_per_gpu`: The maximum number of tokens to be processed in forward and backward when updating the policy. Effective when `use_dynamic_bsz=true`.
- `ulysses_sequence_parallel_size`: Sequence parallel size.
- `max_checkpoints_to_keep`: Maximum number of checkpoints to keep. Once the limit is exceeded, the oldest checkpoints are deleted. If not specified, all checkpoints are kept (see the sketch after this list).
- `trainer_config`: The trainer configuration provided inline.
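
This limit is forwarded to verl's `max_actor_ckpt_to_keep` and `max_critic_ckpt_to_keep` (see the `trinity/common/verl_config.py` change further down). As a rough mental model, the resulting behavior looks like this sketch; `prune_checkpoints` is a hypothetical helper, not Trinity's actual code:

```python
import os
import re
import shutil
from typing import Optional


def prune_checkpoints(checkpoint_dir: str, max_checkpoints_to_keep: Optional[int]) -> None:
    """Keep only the newest `global_step_*` directories; None keeps everything."""
    if max_checkpoints_to_keep is None:
        return
    # Collect checkpoint step numbers, oldest first.
    pattern = re.compile(r"^global_step_(\d+)$")
    steps = sorted(
        int(m.group(1))
        for name in os.listdir(checkpoint_dir)
        if (m := pattern.match(name))
    )
    # Drop everything except the newest `max_checkpoints_to_keep` steps
    # (assumes max_checkpoints_to_keep >= 1).
    for step in steps[:-max_checkpoints_to_keep]:
        shutil.rmtree(os.path.join(checkpoint_dir, f"global_step_{step}"))
```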
---

2 changes: 2 additions & 0 deletions docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
@@ -472,6 +472,7 @@ trainer:
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
max_checkpoints_to_keep: 5
trainer_config: null
```

@@ -496,6 +497,7 @@ trainer:
- `use_dynamic_bsz`: Whether to use dynamic batch size.
- `max_token_len_per_gpu`: The maximum token length per GPU during training; effective when `use_dynamic_bsz=true`.
- `ulysses_sequence_parallel_size`: The degree of sequence parallelism, i.e., the number of GPUs used to split a single sequence.
- `max_checkpoints_to_keep`: The maximum number of checkpoints to keep. Once this limit is exceeded, the oldest checkpoints are deleted.
- `trainer_config`: The trainer configuration provided inline.

---
11 changes: 10 additions & 1 deletion tests/trainer/trainer_test.py
@@ -768,7 +768,7 @@ def setUp(self):
self.config.check_and_update()
self.process_list = []

def test_trainer(self):
def test_trainer(self): # noqa: C901
"""Test the checkpoint saving."""
_trainer_config = self.config.trainer.trainer_config
if self.strategy == "megatron":
@@ -839,6 +839,10 @@ def test_trainer(self):
# print(f"State dict check at {state_dict_iteration} iteration passed.") # for debug

if checkpoint_iteration > 0:
flag_file = os.path.join(
default_local_dir, f"global_step_{checkpoint_iteration}", ".full_checkpoint"
)
self.assertTrue(os.path.exists(flag_file))
for sub_dir_name in ["critic", "actor"]:
iteration_dir = os.path.join(
default_local_dir, f"global_step_{checkpoint_iteration}", sub_dir_name
@@ -882,6 +886,11 @@ def test_trainer(self):
# print(f"Checkpoint check at {checkpoint_iteration} iteration passed.") # for debug
if not stop_event.is_set():
self.fail("Training process failed to stop.")
# check only full checkpoint dirs are kept
for sync_step in [0, 1, 2, 3]:
state_dict_dir = os.path.join(default_local_dir, f"global_step_{sync_step}")
self.assertFalse(os.path.exists(state_dict_dir))
self.assertTrue(os.path.exists(os.path.join(default_local_dir, "global_step_4")))
trainer_process.join(timeout=10)
self.assertIn("model.safetensors", huggingface_dir_files)

1 change: 1 addition & 0 deletions trinity/common/config.py
@@ -735,6 +735,7 @@ class TrainerConfig:
# TODO: extract more train-related params from underlying trainer engine

save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED
max_checkpoints_to_keep: Optional[int] = None

trainer_config: Any = field(default_factory=dict)
trainer_config_path: str = "" # deprecated, use `trainer_config` instead
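The new field is a plain optional on the `TrainerConfig` dataclass, so it can also be set programmatically, not just via YAML. A tiny sketch (assuming, as the excerpt suggests, that `TrainerConfig`'s fields all have defaults):

```python
from trinity.common.config import TrainerConfig

cfg = TrainerConfig()  # assumption: all fields have defaults, as in the excerpt
cfg.max_checkpoints_to_keep = 5  # None (the default) keeps every checkpoint
```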
4 changes: 4 additions & 0 deletions trinity/common/models/utils.py
@@ -199,6 +199,10 @@ def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, Tuple[str, str]]:
Args:
checkpoint_dir (str): The checkpoint directory.
config (TrainerConfig): The trainer config. Only the "verl" trainer type is supported for now.

Returns:
Union[dict, Tuple[str, str]]: The state dict. If the checkpoint uses
Megatron dist checkpointing, a tuple of (method, checkpoint_dir) is returned instead.
"""
if config.trainer_type == "verl":
strategy = config.trainer_strategy
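Callers are expected to branch on this return type, as the `vllm_worker.py` change below does. A condensed sketch of that dispatch (the `print` calls stand in for the real weight-loading logic):

```python
from typing import Tuple, Union


def apply_state_dict(result: Union[dict, Tuple[str, str]]) -> None:
    """Illustrative dispatch over load_state_dict's return value."""
    if isinstance(result, tuple):
        # Megatron dist checkpointing: we received (method, checkpoint_dir)
        # instead of tensors, so weights must be loaded from the directory.
        method, checkpoint_dir = result
        if method == "megatron":
            print(f"convert and load weights from {checkpoint_dir}")
        else:
            raise ValueError(f"unknown method: {method}")
    else:
        # A plain state dict: consume the tensors directly.
        print(f"loading {len(result)} tensors")
```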
1 change: 1 addition & 0 deletions trinity/common/models/vllm_worker.py
@@ -68,6 +68,7 @@ def update_weight(self):
if self._weight_update_rank == 0:
state_dict, model_version = ray.get(self.synchronizer.get_model_state_dict.remote())
if isinstance(state_dict, tuple):
# currently only megatron returns a tuple
method, checkpoint_dir = state_dict
if method == "megatron":
if self._checkpoint_converter is None:
3 changes: 3 additions & 0 deletions trinity/common/verl_config.py
@@ -413,6 +413,9 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
self.trainer.group_name = config.group
self.trainer.experiment_name = config.name
self.trainer.default_local_dir = config.checkpoint_job_dir
if config.trainer.max_checkpoints_to_keep is not None:
self.trainer.max_actor_ckpt_to_keep = config.trainer.max_checkpoints_to_keep
self.trainer.max_critic_ckpt_to_keep = config.trainer.max_checkpoints_to_keep
if not config.continue_from_checkpoint:
self.trainer.resume_mode = "disable"
else:
27 changes: 22 additions & 5 deletions trinity/manager/synchronizer.py
@@ -2,6 +2,7 @@

import asyncio
import os
import shutil
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

@@ -95,13 +96,14 @@ async def _find_verl_latest_state_dict(self) -> None:
)
while True:
if os.path.exists(local_latest_state_dict_iteration):
current_model_version = self.model_version
try:
with open(local_latest_state_dict_iteration, "r") as f:
latest_model_version = int(f.read().strip())
except (IOError, ValueError) as e:
self.logger.warning(f"Failed to read or parse state dict iteration file: {e}")
continue
if latest_model_version > self.model_version:
if latest_model_version > current_model_version:
self.logger.info(
f"Synchronizer has found a new model state dict at step {latest_model_version}."
)
@@ -119,8 +121,25 @@
f"Synchronizer has loaded model state dict from checkpoint {latest_model_version}."
)
await self.set_model_state_dict(model_state_dict, latest_model_version)
# remove the previous checkpoints to save disk space
await self._remove_previous_state_dict(current_model_version)
await asyncio.sleep(1)

async def _remove_previous_state_dict(self, previous_model_version: int) -> None:
self.logger.info(
f"Synchronizer is removing previous checkpoint for sync at step {previous_model_version}."
)
previous_state_dict_dir = os.path.join(
self.config.checkpoint_job_dir, f"global_step_{previous_model_version}"
)
if os.path.exists(previous_state_dict_dir):
# only remove checkpoints saved for sync; keep full checkpoints
if not os.path.exists(os.path.join(previous_state_dict_dir, ".full_checkpoint")):
self.logger.info(
f"Removing previous checkpoint for sync at step {previous_model_version}."
)
shutil.rmtree(previous_state_dict_dir)

async def _find_tinker_latest_state_dict(self) -> None:
default_local_dir = self.config.checkpoint_job_dir
local_latest_state_dict_iteration = os.path.join(
@@ -320,17 +339,15 @@ async def get_latest_model_version(self) -> int:
async with self._ready_condition:
return self.model_version

async def ready_to_nccl_sync(
self, module: str, trainer_step: Optional[int] = None
) -> Union[int, None]:
async def ready_to_nccl_sync(self, module: str, trainer_step: int) -> Union[int, None]:
"""
Prepare for NCCL-based synchronization between modules.

Only supports one explorer currently.

Args:
module: Either 'trainer' or 'explorer'.
trainer_step: Optional step number from the trainer.
trainer_step: Step number from the trainer.

Returns:
The model version if both sides are ready; otherwise None.
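Putting the synchronizer and trainer changes together, `.full_checkpoint` is a two-sided protocol: the trainers (below, in `tinker_trainer.py` and `verl_trainer.py`) drop an empty marker file into each full checkpoint directory before the synchronizer is notified, and the synchronizer only deletes step directories that lack the marker. A minimal sketch of both sides, with names taken from the diff:

```python
import os
import shutil

FLAG_NAME = ".full_checkpoint"


def mark_full_checkpoint(step_dir: str) -> None:
    # Trainer side: create the flag before notifying the synchronizer,
    # so the directory is never mistaken for a state_dict-only checkpoint.
    with open(os.path.join(step_dir, FLAG_NAME), "w") as f:
        f.write("")


def remove_if_sync_only(step_dir: str) -> None:
    # Synchronizer side: only state-dict-only dirs (no flag) are deleted.
    if os.path.exists(step_dir) and not os.path.exists(os.path.join(step_dir, FLAG_NAME)):
        shutil.rmtree(step_dir)
```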
9 changes: 9 additions & 0 deletions trinity/trainer/tinker_trainer.py
@@ -282,6 +282,15 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False):
f"global_step_{self.train_step_num}",
)
os.makedirs(local_path, exist_ok=True)

# save a flag to indicate this is a full checkpoint dir
# make sure this flag is created before notifying the synchronizer
# to avoid the synchronizer recognizing it as a state_dict-only checkpoint
# TODO: use a better way to indicate full checkpoint
flag_path = os.path.join(local_path, ".full_checkpoint")
with open(flag_path, "w") as f:
f.write("")

remote_checkpoint_path = os.path.join(local_path, "remote_checkpoint_path.txt")
with open(remote_checkpoint_path, "w") as f:
f.write(self.latest_remote_checkpoint_path)
8 changes: 8 additions & 0 deletions trinity/trainer/verl_trainer.py
@@ -494,6 +494,14 @@ def _save_checkpoint(self, save_as_hf: bool = False):
self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
)

# save a flag to indicate this is a full checkpoint dir
# make sure this flag is created before notifying the synchronizer
# to avoid the synchronizer recognizing it as a state_dict-only checkpoint
# TODO: use a better way to indicate full checkpoint
flag_path = os.path.join(local_global_step_folder, ".full_checkpoint")
with open(flag_path, "w") as f:
f.write("")

self.logger.info(f"local_global_step_folder: {local_global_step_folder}")
actor_local_path = os.path.join(local_global_step_folder, "actor")
