
Commit 2be8bb6

Remove checkpoints saved for sync purpose (#476)
1 parent 34c0eab commit 2be8bb6

13 files changed (+124, -12 lines)

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions

@@ -475,6 +475,7 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+  max_checkpoints_to_keep: 5
   trainer_config: null
 ```

@@ -499,6 +500,7 @@ trainer:
 - `use_dynamic_bsz`: Whether to use dynamic batch size.
 - `max_token_len_per_gpu`: The maximum number of tokens to be processed in forward and backward when updating the policy. Effective when `use_dynamic_bsz=true`.
 - `ulysses_sequence_parallel_size`: Sequence parallel size.
+- `max_checkpoints_to_keep`: Maximum number of checkpoints to keep. Older checkpoints will be deleted. If not specified, all checkpoints will be kept.
 - `trainer_config`: The trainer configuration provided inline.
 ---
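The retention rule documented above can be pictured with a short, stand-alone sketch; it is illustrative only (in this commit the actual pruning is delegated to verl via `max_actor_ckpt_to_keep`/`max_critic_ckpt_to_keep`, see the `trinity/common/verl_config.py` change below). The idea: keep the newest N `global_step_*` directories and delete the rest, with `None` meaning keep everything.

```python
import os
import re
import shutil
from typing import Optional


def prune_checkpoints(checkpoint_root: str, max_checkpoints_to_keep: Optional[int]) -> None:
    """Illustrative helper: keep the newest N `global_step_*` dirs, delete older ones.

    If `max_checkpoints_to_keep` is None, nothing is deleted (all checkpoints are kept).
    """
    if max_checkpoints_to_keep is None:
        return
    pattern = re.compile(r"global_step_(\d+)$")
    steps = []
    for name in os.listdir(checkpoint_root):
        match = pattern.match(name)
        if match:
            steps.append(int(match.group(1)))
    # sort newest first and drop everything beyond the configured limit
    for step in sorted(steps, reverse=True)[max_checkpoints_to_keep:]:
        shutil.rmtree(os.path.join(checkpoint_root, f"global_step_{step}"), ignore_errors=True)
```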

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions

@@ -472,6 +472,7 @@ trainer:
   use_dynamic_bsz: true
   max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
+  max_checkpoints_to_keep: 5
   trainer_config: null
 ```

@@ -496,6 +497,7 @@ trainer:
 - `use_dynamic_bsz`: Whether to use dynamic batch size.
 - `max_token_len_per_gpu`: The maximum token length per GPU during training; effective when `use_dynamic_bsz=true`.
 - `ulysses_sequence_parallel_size`: The degree of sequence parallelism, i.e., the number of GPUs used to split a single sequence.
+- `max_checkpoints_to_keep`: Maximum number of checkpoints to keep. Once this limit is exceeded, the oldest checkpoints will be deleted. If not specified, all checkpoints will be kept.
 - `trainer_config`: The trainer configuration provided inline.

 ---

tests/trainer/trainer_test.py

Lines changed: 40 additions & 7 deletions

@@ -344,6 +344,10 @@ def test_trainer(self, mock_load):
 
         # sft warmup stage
         sft_config = stage_configs[0]
+        self.assertEqual(
+            sft_config.synchronizer.sync_interval,
+            sft_config.trainer.save_interval,
+        )
         parser = TensorBoardParser(os.path.join(sft_config.monitor.cache_dir, "tensorboard"))
         rollout_metrics = parser.metric_list("rollout")
         self.assertEqual(len(rollout_metrics), 0)

@@ -374,11 +378,15 @@ def test_trainer(self, mock_load):
         self.assertEqual(parser.metric_min_step(response_metrics[0]), 1)
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
         # test save checkpoint when sft finish
+        for i in range(3):
+            self.assertFalse(
+                os.path.exists(os.path.join(sft_config.checkpoint_job_dir, f"global_step_{i}"))
+            )
         self.assertEqual(
             get_checkpoint_dir_with_step_num(
-                checkpoint_root_path=sft_config.checkpoint_job_dir, trainer_type="verl", step_num=2
+                checkpoint_root_path=sft_config.checkpoint_job_dir, trainer_type="verl", step_num=3
             )[1],
-            2,
+            3,
         )
         # test save checkpoint at last step
         checkpoint_dir, step_num = get_checkpoint_dir_with_step_num(

@@ -749,7 +757,7 @@ def setUp(self):
         if multiprocessing.get_start_method(allow_none=True) != "spawn":
             multiprocessing.set_start_method("spawn", force=True)
         self.config = get_template_config()
-        self.config.buffer.total_epochs = 1
+        self.config.buffer.total_steps = 6
         self.config.buffer.batch_size = 4
         self.config.model.model_path = get_model_path()
         self.config.explorer.rollout_model.engine_type = "vllm_async"

@@ -762,21 +770,20 @@ def setUp(self):
         self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
         self.config.explorer.eval_interval = 4
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        self.config.trainer.save_interval = 4
+        self.config.trainer.save_interval = 2
         self.config.trainer.save_hf_checkpoint = "last"
         self.config.trainer.trainer_strategy = self.strategy
+        self.config.trainer.max_checkpoints_to_keep = 2
         self.config.check_and_update()
         self.process_list = []
 
-    def test_trainer(self):
+    def test_trainer(self):  # noqa: C901
        """Test the checkpoint saving."""
         _trainer_config = self.config.trainer.trainer_config
         if self.strategy == "megatron":
             _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
             _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2
             _trainer_config.critic.megatron.tensor_model_parallel_size = 2
-            _trainer_config.trainer.max_actor_ckpt_to_keep = 2
-            _trainer_config.trainer.max_critic_ckpt_to_keep = 2
 
         stop_event = multiprocessing.Event()
         trainer_process = multiprocessing.Process(target=run_both, args=(self.config, stop_event))

@@ -839,6 +846,10 @@ def test_trainer(self):
             # print(f"State dict check at {state_dict_iteration} iteration passed.") # for debug
 
             if checkpoint_iteration > 0:
+                flag_file = os.path.join(
+                    default_local_dir, f"global_step_{checkpoint_iteration}", ".full_checkpoint"
+                )
+                self.assertTrue(os.path.exists(flag_file))
                 for sub_dir_name in ["critic", "actor"]:
                     iteration_dir = os.path.join(
                         default_local_dir, f"global_step_{checkpoint_iteration}", sub_dir_name

@@ -882,6 +893,28 @@ def test_trainer(self):
             # print(f"Checkpoint check at {checkpoint_iteration} iteration passed.") # for debug
         if not stop_event.is_set():
             self.fail("Training process failed to stop.")
+        # check only full checkpoint dirs are kept
+        for sync_step in [1, 3, 5]:
+            state_dict_dir = os.path.join(default_local_dir, f"global_step_{sync_step}")
+            self.assertFalse(
+                os.path.exists(state_dict_dir),
+                f"Found unexpected state dict dir at step {sync_step}",
+            )
+        for checkpoint_step in [4, 6]:
+            checkpoint_dir = os.path.join(default_local_dir, f"global_step_{checkpoint_step}")
+            self.assertTrue(
+                os.path.exists(checkpoint_dir),
+                f"Missing expected checkpoint dir at step {checkpoint_step}",
+            )
+            actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
+            self.assertTrue(os.path.exists(actor_checkpoint_dir))
+        # check step 2 should have no checkpoint
+        checkpoint_dir = os.path.join(default_local_dir, "global_step_2")
+        self.assertTrue(os.path.exists(checkpoint_dir))
+        actor_checkpoint_dir = os.path.join(checkpoint_dir, "actor")
+        self.assertFalse(os.path.exists(actor_checkpoint_dir))
+        critic_checkpoint_dir = os.path.join(checkpoint_dir, "critic")
+        self.assertFalse(os.path.exists(critic_checkpoint_dir))
         trainer_process.join(timeout=10)
         self.assertIn("model.safetensors", huggingface_dir_files)

trinity/algorithm/algorithm.py

Lines changed: 22 additions & 0 deletions

@@ -63,6 +63,27 @@ def default_config(cls) -> Dict:
             "entropy_loss_fn": "none",
         }
 
+    @classmethod
+    def check_config(cls, config: Config) -> None:
+        if config.mode == "train":
+            if (
+                config.buffer.trainer_input.experience_buffer is None
+                or not config.buffer.trainer_input.experience_buffer.path
+            ):
+                raise ValueError(
+                    "`buffer.trainer_input.experience_buffer.path` is required when `algorithm.algorithm_type == sft`"
+                )
+        elif config.mode in ["both", "explore"]:
+            raise ValueError(f"SFT does not support `{config.mode}` mode")
+
+        if config.synchronizer.sync_method != SyncMethod.CHECKPOINT:
+            config.synchronizer.sync_method = SyncMethod.CHECKPOINT
+            logger.warning(
+                "SFT only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
+            )
+
+        config.synchronizer.sync_interval = config.trainer.save_interval
+
 
 class PPOAlgorithm(AlgorithmType):
     """PPO Algorithm."""

@@ -232,6 +253,7 @@ def check_config(cls, config: Config) -> None:
             logger.warning(
                 "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
            )
+        config.synchronizer.sync_interval = config.trainer.save_interval
         if config.algorithm.repeat_times != 2:
             config.algorithm.repeat_times = 2  # Fake repeat times
         if config.algorithm.kl_loss_fn in {"none", None}:
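Both the new SFT `check_config` and the existing DPO one now pin `synchronizer.sync_interval` to `trainer.save_interval`: these algorithms synchronize through checkpoints, so a sync step that does not coincide with a saved checkpoint would have nothing to load. A minimal sketch of that coupling, using simplified stand-in classes rather than Trinity's real `Config`:

```python
from dataclasses import dataclass, field


@dataclass
class SynchronizerCfg:
    sync_method: str = "nccl"
    sync_interval: int = 1


@dataclass
class TrainerCfg:
    save_interval: int = 100


@dataclass
class Cfg:
    synchronizer: SynchronizerCfg = field(default_factory=SynchronizerCfg)
    trainer: TrainerCfg = field(default_factory=TrainerCfg)


def check_config(config: Cfg) -> None:
    # simplified mirror of the behavior added in this commit
    if config.synchronizer.sync_method != "checkpoint":
        config.synchronizer.sync_method = "checkpoint"  # checkpoint sync is the only option here
    config.synchronizer.sync_interval = config.trainer.save_interval


cfg = Cfg(trainer=TrainerCfg(save_interval=2))
check_config(cfg)
assert cfg.synchronizer.sync_interval == 2  # sync happens exactly when a checkpoint is saved
```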

trinity/common/config.py

Lines changed: 1 addition & 0 deletions

@@ -735,6 +735,7 @@ class TrainerConfig:
     # TODO: extract more train-related params from underlying trainer engine
 
     save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED
+    max_checkpoints_to_keep: Optional[int] = None
 
     trainer_config: Any = field(default_factory=dict)
     trainer_config_path: str = ""  # deprecated, use `trainer_config` instead

trinity/common/models/utils.py

Lines changed: 4 additions & 0 deletions

@@ -199,6 +199,10 @@ def load_state_dict(checkpoint_dir: str, config: TrainerConfig) -> Union[dict, T
     Args:
         checkpoint_dir (str): The checkpoint directory.
         trainer_type (str): The trainer type. Only support "verl" for now.
+
+    Returns:
+        Union[dict, Tuple[str, str]]: The state dict. If the checkpoint uses
+            megatron dist checkpointing, return a tuple of (method, checkpoint_dir).
     """
     if config.trainer_type == "verl":
         strategy = config.trainer_strategy

trinity/common/models/vllm_worker.py

Lines changed: 1 addition & 0 deletions

@@ -68,6 +68,7 @@ def update_weight(self):
         if self._weight_update_rank == 0:
             state_dict, model_version = ray.get(self.synchronizer.get_model_state_dict.remote())
             if isinstance(state_dict, tuple):
+                # currently only megatron return a tuple
                 method, checkpoint_dir = state_dict
                 if method == "megatron":
                     if self._checkpoint_converter is None:
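Taken together with the `utils.py` change above: `load_state_dict` now documents that it returns either an in-memory state dict or a `(method, checkpoint_dir)` tuple, and `update_weight` branches on that shape. A hedged caller-side sketch of the same dispatch (the two loader helpers are hypothetical placeholders, not Trinity or vLLM APIs):

```python
from typing import Tuple, Union


def load_from_megatron_dist_checkpoint(checkpoint_dir: str) -> None:
    """Hypothetical stand-in for converting a megatron dist checkpoint into weights."""
    print(f"converting megatron dist checkpoint at {checkpoint_dir}")


def load_from_state_dict(state_dict: dict) -> None:
    """Hypothetical stand-in for loading tensors from an in-memory state dict."""
    print(f"loading {len(state_dict)} tensors")


def apply_weights(result: Union[dict, Tuple[str, str]]) -> None:
    # dispatch on the two shapes load_state_dict may return
    if isinstance(result, tuple):
        # currently only megatron returns a tuple of (method, checkpoint_dir)
        method, checkpoint_dir = result
        if method == "megatron":
            load_from_megatron_dist_checkpoint(checkpoint_dir)
    else:
        load_from_state_dict(result)
```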

trinity/common/verl_config.py

Lines changed: 3 additions & 0 deletions

@@ -413,6 +413,9 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         self.trainer.group_name = config.group
         self.trainer.experiment_name = config.name
         self.trainer.default_local_dir = config.checkpoint_job_dir
+        if config.trainer.max_checkpoints_to_keep is not None:
+            self.trainer.max_actor_ckpt_to_keep = config.trainer.max_checkpoints_to_keep
+            self.trainer.max_critic_ckpt_to_keep = config.trainer.max_checkpoints_to_keep
         if not config.continue_from_checkpoint:
             self.trainer.resume_mode = "disable"
         else:
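verl tracks actor and critic checkpoint retention with separate knobs, so the single Trinity-level `max_checkpoints_to_keep` value is simply copied into both; when it is left unset, verl's own defaults apply. A simplified stand-in (not the real config classes) showing the fan-out:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class VerlTrainerCfg:
    # left as None, verl's own retention defaults apply
    max_actor_ckpt_to_keep: Optional[int] = None
    max_critic_ckpt_to_keep: Optional[int] = None


def apply_retention(verl_trainer: VerlTrainerCfg, max_checkpoints_to_keep: Optional[int]) -> None:
    # one Trinity-level knob fans out to both verl limits
    if max_checkpoints_to_keep is not None:
        verl_trainer.max_actor_ckpt_to_keep = max_checkpoints_to_keep
        verl_trainer.max_critic_ckpt_to_keep = max_checkpoints_to_keep


cfg = VerlTrainerCfg()
apply_retention(cfg, 5)
assert cfg.max_actor_ckpt_to_keep == cfg.max_critic_ckpt_to_keep == 5
```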

trinity/manager/synchronizer.py

Lines changed: 19 additions & 5 deletions

@@ -2,6 +2,7 @@
 
 import asyncio
 import os
+import shutil
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union
 

@@ -95,13 +96,14 @@ async def _find_verl_latest_state_dict(self) -> None:
         )
         while True:
             if os.path.exists(local_latest_state_dict_iteration):
+                current_model_version = self.model_version
                 try:
                     with open(local_latest_state_dict_iteration, "r") as f:
                         latest_model_version = int(f.read().strip())
                 except (IOError, ValueError) as e:
                     self.logger.warning(f"Failed to read or parse state dict iteration file: {e}")
                     continue
-                if latest_model_version > self.model_version:
+                if latest_model_version > current_model_version:
                     self.logger.info(
                         f"Synchronizer has found a new model state dict at step {latest_model_version}."
                     )

@@ -119,8 +121,22 @@ async def _find_verl_latest_state_dict(self) -> None:
                         f"Synchronizer has loaded model state dict from checkpoint {latest_model_version}."
                     )
                     await self.set_model_state_dict(model_state_dict, latest_model_version)
+                    # remove the previous checkpoints to save disk space
+                    await self._remove_previous_state_dict(current_model_version)
             await asyncio.sleep(1)
 
+    async def _remove_previous_state_dict(self, previous_model_version: int) -> None:
+        previous_state_dict_dir = os.path.join(
+            self.config.checkpoint_job_dir, f"global_step_{previous_model_version}"
+        )
+        if os.path.exists(previous_state_dict_dir):
+            # check if it's a full checkpoint, only remove checkpoints for sync
+            if not os.path.exists(os.path.join(previous_state_dict_dir, ".full_checkpoint")):
+                self.logger.info(
+                    f"Removing previous checkpoint for sync at step {previous_model_version}."
+                )
+                shutil.rmtree(previous_state_dict_dir, ignore_errors=True)
+
     async def _find_tinker_latest_state_dict(self) -> None:
         default_local_dir = self.config.checkpoint_job_dir
         local_latest_state_dict_iteration = os.path.join(

@@ -320,17 +336,15 @@ async def get_latest_model_version(self) -> int:
         async with self._ready_condition:
             return self.model_version
 
-    async def ready_to_nccl_sync(
-        self, module: str, trainer_step: Optional[int] = None
-    ) -> Union[int, None]:
+    async def ready_to_nccl_sync(self, module: str, trainer_step: int) -> Union[int, None]:
         """
         Prepare for NCCL-based synchronization between modules.
 
         Only supports one explorer currently.
 
         Args:
             module: Either 'trainer' or 'explorer'.
-            trainer_step: Optional step number from the trainer.
+            trainer_step: Step number from the trainer.
 
         Returns:
             The model version if both sides are ready; otherwise None.
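The `.full_checkpoint` marker checked by `_remove_previous_state_dict` (and written by `tinker_trainer.py` below) is a small filesystem protocol: the trainer drops an empty flag file into every full checkpoint directory before notifying the synchronizer, and the synchronizer only deletes `global_step_*` directories that lack the flag. A condensed, self-contained sketch of both sides of that contract (paths and helper names below are illustrative, not the project's API):

```python
import os
import shutil
import tempfile

FULL_CHECKPOINT_FLAG = ".full_checkpoint"


def mark_full_checkpoint(step_dir: str) -> None:
    """Producer side: create the flag before the synchronizer is notified."""
    os.makedirs(step_dir, exist_ok=True)
    with open(os.path.join(step_dir, FULL_CHECKPOINT_FLAG), "w") as f:
        f.write("")


def remove_if_sync_only(step_dir: str) -> None:
    """Consumer side: delete state-dict-only dirs, keep full checkpoints."""
    if os.path.exists(step_dir) and not os.path.exists(
        os.path.join(step_dir, FULL_CHECKPOINT_FLAG)
    ):
        shutil.rmtree(step_dir, ignore_errors=True)


# usage: the sync-only dir is pruned, the full checkpoint survives
root = tempfile.mkdtemp()
mark_full_checkpoint(os.path.join(root, "global_step_4"))
os.makedirs(os.path.join(root, "global_step_3"))  # state dict saved only for sync
remove_if_sync_only(os.path.join(root, "global_step_3"))
remove_if_sync_only(os.path.join(root, "global_step_4"))
assert not os.path.exists(os.path.join(root, "global_step_3"))
assert os.path.exists(os.path.join(root, "global_step_4"))
```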

trinity/trainer/tinker_trainer.py

Lines changed: 9 additions & 0 deletions

@@ -282,6 +282,15 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa
             f"global_step_{self.train_step_num}",
         )
         os.makedirs(local_path, exist_ok=True)
+
+        # save a flag to indicate this is a full checkpoint dir
+        # make sure this flag is created before notifying the synchronizer
+        # to avoid the synchronizer recognizing it as a state_dict-only checkpoint
+        # TODO: use a better way to indicate full checkpoint
+        flag_path = os.path.join(local_path, ".full_checkpoint")
+        with open(flag_path, "w") as f:
+            f.write("")
+
         remote_checkpoint_path = os.path.join(local_path, "remote_checkpoint_path.txt")
         with open(remote_checkpoint_path, "w") as f:
             f.write(self.latest_remote_checkpoint_path)
