
Commit cdfd70e

add CheckpointMonitor and refactor checkpoint manager
1 parent 3490c2c commit cdfd70e

6 files changed: +169 additions, -86 deletions

trinity/manager/synchronizer.py
Lines changed: 29 additions & 1 deletion

@@ -8,7 +8,7 @@
 import ray

 from trinity.common.config import Config
-from trinity.common.constants import RunningStatus
+from trinity.common.constants import RunningStatus, SyncMethod
 from trinity.common.models.utils import (
     get_checkpoint_dir_with_step_num,
     load_state_dict,
@@ -43,6 +43,8 @@ def __init__(self, config: Config, module_ref: ray.actor.ActorHandle):
         self._modules = {module_ref}
         self._modules_lock = asyncio.Lock()
         asyncio.create_task(self._check_modules())
+        if self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT:
+            asyncio.create_task(self._find_latest_state_dict())

     async def add_module(self, module_ref: ray.actor.ActorHandle) -> None:
         """Adds a module to be tracked by the synchronizer.
@@ -72,6 +74,32 @@ async def _check_modules(self) -> None:
             except Exception:
                 pass

+    async def _find_latest_state_dict(self) -> None:
+        assert self.config.trainer.trainer_type == "verl"
+        default_local_dir = self.config.trainer.trainer_config.trainer.default_local_dir
+        local_latest_state_dict_iteration = os.path.join(
+            default_local_dir, "latest_state_dict_iteration.txt"
+        )
+        while True:
+            if os.path.exists(local_latest_state_dict_iteration):
+                with open(local_latest_state_dict_iteration, "r") as f:
+                    latest_model_version = int(f.read().strip())
+                if latest_model_version > self.model_version:
+                    self.logger.info(
+                        f"Synchronizer has found a new model state dict at step {latest_model_version}."
+                    )
+                    model_state_dict = load_state_dict(
+                        os.path.join(
+                            default_local_dir, f"global_step_{latest_model_version}", "actor"
+                        ),
+                        self.config.trainer,
+                    )
+                    self.logger.info(
+                        f"Synchronizer has loaded model state dict from checkpoint {self.model_version}."
+                    )
+                    await self.set_model_state_dict(model_state_dict, latest_model_version)
+            await asyncio.sleep(1)
+
     async def set_trainer_status(self, status: RunningStatus):
         """Update the status of the trainer."""
         async with self._ready_condition:
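The new _find_latest_state_dict loop polls latest_state_dict_iteration.txt once per second and parses it with int(f.read().strip()), so whichever component writes that file (the producer is not shown in this diff; presumably the trainer side or the new CheckpointMonitor) should write the step number atomically so the reader never sees a half-written value. A minimal sketch of such a writer follows; the helper name write_latest_state_dict_iteration and the temp-file-plus-rename strategy are assumptions for illustration, not part of this commit.

import os
import tempfile


def write_latest_state_dict_iteration(default_local_dir: str, global_step: int) -> None:
    """Hypothetical producer for latest_state_dict_iteration.txt (not part of this commit).

    Writes the step number to a temporary file first, then renames it into place so the
    Synchronizer's polling loop never reads a partially written value.
    """
    target = os.path.join(default_local_dir, "latest_state_dict_iteration.txt")
    fd, tmp_path = tempfile.mkstemp(dir=default_local_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            f.write(str(global_step))
        # os.replace is atomic on POSIX: the reader sees either the old or the new content.
        os.replace(tmp_path, target)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

With an atomic rename, a torn read cannot happen and a stale read only delays the synchronization by one polling interval.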

trinity/trainer/trainer.py
Lines changed: 0 additions & 1 deletion

@@ -189,7 +189,6 @@ def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = Fa
             current_exp_index=self.engine.train_step_num * self.config.buffer.train_batch_size,
             current_step=self.train_step_num,
         )
-        self.logger.info(f"Checkpoint at step {self.train_step_num} saved.")
         return metrics

     async def shutdown(self) -> None:

trinity/trainer/verl/fsdp_checkpoint_manager.py
Lines changed: 35 additions & 40 deletions

@@ -46,8 +46,8 @@
 )
 from verl.utils.logger import log_with_rank

-from trinity.common.constants import SyncMethod
 from trinity.manager.synchronizer import Synchronizer
+from trinity.trainer.verl_trainer import CheckpointMonitor


 class FSDPCheckpointManager(OldFSDPCheckpointManager):
@@ -60,15 +60,12 @@ class FSDPCheckpointManager(OldFSDPCheckpointManager):
     This class is useful in distributed training scenarios where synchronization and non-blocking I/O are important.
     """

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, ray_namespace: str = "", **kwargs):
         super().__init__(*args, **kwargs)
-        config = kwargs.pop("config", None)
-        self.synchronizer_config = config
-        if config is not None:
-            # Retrieve the remote Synchronizer actor using the provided namespace
-            self.synchronizer = Synchronizer.get_actor(namespace=config.ray_namespace)
-        else:
-            self.synchronizer = None
+        self.synchronizer = Synchronizer.get_actor(namespace=ray_namespace)
+        self.checkpoint_monitor = CheckpointMonitor.get_actor(
+            namespace=ray_namespace,
+        )

         # Threads for asynchronous saving of different components
         self._model_state_dict_thread = None
@@ -77,21 +74,6 @@ def __init__(self, *args, **kwargs):
         self._save_model_thread = None
         self.previous_state_dict_step = None

-    def _notify_synchronizer_with_step_num(self, global_step):
-        """
-        Notifies the Synchronizer actor about the current training step number,
-        used when SyncMethod is CHECKPOINT.
-
-        Args:
-            global_step (int): The current global training step.
-        """
-        if getattr(self.synchronizer_config, "sync_method", None) == SyncMethod.CHECKPOINT:
-            ray.get(
-                self.synchronizer.set_model_state_dict_with_step_num.remote(
-                    global_step, self.world_size
-                )
-            )
-
     def _upload_state_dict(self, state_dict: Union[dict, None], global_step: int):
         """
         Internal method to upload a state dict to the Synchronizer actor.
@@ -131,14 +113,16 @@ def _save_model_state_dict():
                 rank=self.rank,
                 logger=logger,
            )
-            self._notify_synchronizer_with_step_num(global_step)
+            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step, True))

        self._model_state_dict_thread = threading.Thread(
            target=_save_model_state_dict,
        )
        self._model_state_dict_thread.start()

-    def _save_optimizer(self, local_path):
+        self.previous_state_dict_step = global_step
+
+    def _save_optimizer(self, local_path, global_step):
         optim_path = os.path.join(
             local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt"
         )
@@ -153,13 +137,14 @@ def _save_optimizer_state_dict():
                 rank=self.rank,
                 logger=logger,
             )
+            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step))

        self._optimizer_state_dict_thread = threading.Thread(
            target=_save_optimizer_state_dict,
        )
        self._optimizer_state_dict_thread.start()

-    def _save_extra_state(self, local_path):
+    def _save_extra_state(self, local_path, global_step):
         extra_path = os.path.join(
             local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt"
         )
@@ -180,6 +165,7 @@ def _save_extra_state_dict():
                 rank=self.rank,
                 logger=logger,
             )
+            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step))

        self._extra_state_dict_thread = threading.Thread(
            target=_save_extra_state_dict,
@@ -193,11 +179,12 @@ def save_state_dict( # noqa: C901
         global_step: int = 0,
     ):
         if self.previous_state_dict_step is None:
+            # First sync in trainer.prepare
             self.previous_state_dict_step = global_step
             self._upload_state_dict(None, global_step)
             return
         elif self.previous_state_dict_step == global_step:
-            self._notify_synchronizer_with_step_num(global_step)
+            # No need to save for sync again
             return
         if local_path is None:
             return
@@ -213,8 +200,7 @@ def save_state_dict( # noqa: C901
             self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg
         ):
             self._save_model(local_path, global_step)
-
-        self.previous_state_dict_step = global_step
+        ray.get(self.checkpoint_monitor.register_state_dict_save_count.remote(global_step, 1))

     def save_checkpoint( # noqa: C901
         self,
@@ -239,12 +225,14 @@ def save_checkpoint( # noqa: C901
             hdfs_path (str, optional): HDFS path for saving the checkpoint (not implemented here).
             global_step (int): Current training step.
             max_ckpt_to_keep (int, optional): Maximum number of checkpoints to keep locally.
-            model_state_dict_only (bool): Whether to only save the model state dict (no optimizer, etc.).
             save_as_hf (bool): Whether to force save the model in Hugging Face format.
         """
         if local_path is None:
             return

+        # record the previous global step
+        self.previous_global_step = global_step
+
         # remove previous local_path, only rank 0 should do this
         if (
             self.rank == 0
@@ -270,6 +258,9 @@ def save_checkpoint( # noqa: C901
                 self.optimizer is not None
             ), "optimizer must be provided when checkpoint_contents.save includes ['optimizer']"

+        state_dict_thread_count = 0
+        other_thread_count = 0
+
         # every rank will save its own model and optim shard
         state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False)
         optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False)
@@ -279,16 +270,17 @@ def save_checkpoint( # noqa: C901
             self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg
         ):
             if self.should_save_model:
-                if self.previous_state_dict_step == global_step:
-                    self._notify_synchronizer_with_step_num(global_step)
-                else:
+                if self.previous_state_dict_step != global_step:
+                    state_dict_thread_count += 1
                     self._save_model(local_path, global_step)

             if self.should_save_optimizer:
-                self._save_optimizer(local_path)
+                other_thread_count += 1
+                self._save_optimizer(local_path, global_step)

             if self.should_save_extra:
-                self._save_extra_state(local_path)
+                other_thread_count += 1
+                self._save_extra_state(local_path, global_step)

         if self.rank == 0:
             # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether
@@ -341,6 +333,7 @@ def save_checkpoint( # noqa: C901
             state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True)

             if self.rank == 0:
+                other_thread_count += 1
                 hf_local_path = os.path.join(local_path, "huggingface")
                 os.makedirs(hf_local_path, exist_ok=True)

@@ -386,19 +379,21 @@ def _save_model():
                         logger=logger,
                         log_only_rank_0=True,
                     )
+                    ray.get(self.checkpoint_monitor.notify_finished.remote(global_step))

                 self._save_model_thread = threading.Thread(
                     target=_save_model,
                 )
                 self._save_model_thread.start()
-                self.processing_class.save_pretrained(hf_local_path)

         # wait for rank0 to dump hf_model to local
         torch.distributed.barrier()

-        # record the previous global step
-        self.previous_global_step = global_step
-        self.previous_state_dict_step = global_step
+        ray.get(
+            self.checkpoint_monitor.register_checkpoint_save_count.remote(
+                global_step, state_dict_thread_count, other_thread_count
+            )
+        )
         self.previous_saved_paths.append(local_path)

     def wait_on_save_thread(self) -> None:
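The CheckpointMonitor imported from trinity.trainer.verl_trainer is the sixth changed file and its implementation is not shown in this view. From the call sites above, the checkpoint managers announce how many save threads a step will spawn (register_state_dict_save_count / register_checkpoint_save_count) and each worker thread calls notify_finished when its shard is on disk, with a flag marking the model-state-dict thread. The following is a rough sketch of a Ray actor with that shape, reconstructed from these call sites alone; the method bodies, the get_actor helper, the actor name, and the completion handling are assumptions, not the actual implementation in verl_trainer.py.

from collections import defaultdict

import ray


class CheckpointMonitor:
    """Hypothetical reconstruction of the monitor interface, inferred from its call sites."""

    def __init__(self):
        # per-step counters: how many save threads were announced vs. have finished
        self.expected = defaultdict(int)
        self.finished = defaultdict(int)
        self.state_dict_ready_step = None

    @classmethod
    def get_actor(cls, namespace: str = ""):
        # assumed helper: fetch (or lazily create) a named, detached actor in the given namespace
        return (
            ray.remote(cls)
            .options(
                name="checkpoint_monitor",
                namespace=namespace,
                get_if_exists=True,
                lifetime="detached",
            )
            .remote()
        )

    def register_state_dict_save_count(self, global_step: int, count: int):
        # save_state_dict path: each rank announces `count` model-state-dict threads for this step
        self.expected[global_step] += count

    def register_checkpoint_save_count(
        self, global_step: int, state_dict_count: int, other_count: int
    ):
        # save_checkpoint path: model-state-dict threads plus optimizer/extra/HF threads
        self.expected[global_step] += state_dict_count + other_count

    def notify_finished(self, global_step: int, is_state_dict: bool = False):
        # called from each worker thread once its shard is written; threads may report
        # before registration lands, so only the final comparison is meaningful
        self.finished[global_step] += 1
        if is_state_dict:
            self.state_dict_ready_step = global_step
        return self.finished[global_step] >= self.expected[global_step]

The real actor presumably does more than return a boolean once the counts match (for example, marking the checkpoint complete for the Synchronizer's CHECKPOINT-based weight sync); that follow-up is intentionally left out of this sketch.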

trinity/trainer/verl/fsdp_workers.py
Lines changed: 3 additions & 3 deletions

@@ -569,6 +569,7 @@ def init_model(self):
             lr_scheduler=None,
             processing_class=self.processor if self.processor is not None else self.tokenizer,
             checkpoint_config=self.config.ref.checkpoint,
+            ray_namespace=self.config.synchronizer.ray_namespace,
         )

         if self._is_actor:
@@ -579,7 +580,7 @@ def init_model(self):
             lr_scheduler=self.actor_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer,
             checkpoint_config=self.config.actor.checkpoint,
-            config=self.config.synchronizer,
+            ray_namespace=self.config.synchronizer.ray_namespace,
         )

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
@@ -870,7 +871,6 @@ def save_checkpoint(
         hdfs_path=None,
         global_step=0,
         max_ckpt_to_keep=None,
-        model_state_dict_only=False,
         save_as_hf: bool = False,
     ):
         # only support save and load ckpt for actor
@@ -882,7 +882,6 @@ def save_checkpoint(
             hdfs_path=hdfs_path,
             global_step=global_step,
             max_ckpt_to_keep=max_ckpt_to_keep,
-            model_state_dict_only=model_state_dict_only,
             save_as_hf=save_as_hf,
         )
         dist.barrier()
@@ -1233,6 +1232,7 @@ def init_model(self):
             lr_scheduler=self.critic_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer,
             checkpoint_config=self.config.checkpoint,
+            ray_namespace=self.config.synchronizer.ray_namespace,
         )

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)

trinity/trainer/verl/megatron_checkpoint_manager.py
Lines changed: 9 additions & 25 deletions

@@ -38,9 +38,8 @@
     get_transformer_config_checkpoint_path,
 )

-from trinity.common.config import SynchronizerConfig
-from trinity.common.constants import SyncMethod
 from trinity.manager.synchronizer import Synchronizer
+from trinity.trainer.verl_trainer import CheckpointMonitor


 class MegatronCheckpointManager(OldMegatronCheckpointManager):
@@ -53,34 +52,17 @@ class MegatronCheckpointManager(OldMegatronCheckpointManager):
     def __init__(
         self,
         *args,
-        sync_config: SynchronizerConfig = None,
+        ray_namespace: str = "",
         **kwargs,
     ):
         super().__init__(
             *args,
             **kwargs,
         )
-        self.synchronizer_config = sync_config
-        if sync_config is not None:
-            # Retrieve the remote Synchronizer actor using the provided namespace
-            self.synchronizer = Synchronizer.get_actor(namespace=sync_config.ray_namespace)
-        else:
-            self.synchronizer = None
-
-    def _notify_synchronizer_with_step_num(self, global_step):
-        """
-        Notifies the Synchronizer actor about the current training step number,
-        used when SyncMethod is CHECKPOINT.
-
-        Args:
-            global_step (int): The current global training step.
-        """
-        if getattr(self.synchronizer_config, "sync_method", None) == SyncMethod.CHECKPOINT:
-            ray.get(
-                self.synchronizer.set_model_state_dict_with_step_num.remote(
-                    global_step, self.world_size
-                )
-            )
+        self.synchronizer = Synchronizer.get_actor(namespace=ray_namespace)
+        self.checkpoint_monitor = CheckpointMonitor.get_actor(
+            namespace=ray_namespace,
+        )

     def save_checkpoint( # noqa: C901
         self,
@@ -260,14 +242,16 @@ def save_checkpoint( # noqa: C901
             log_only_rank_0=True,
         )

+        ray.get(self.checkpoint_monitor.register_checkpoint_save_count.remote(global_step, 1, 0))
+
         def finalize_save_fn():
             # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided
             log_with_rank(
                 f"Dist checkpointing save completed for {dist_checkpoint_path}",
                 rank=self.rank,
                 logger=logger,
             )
-            self._notify_synchronizer_with_step_num(global_step)
+            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step, True))
             if self.rank == 0:
                 if hdfs_path is not None:
                     log_with_rank(
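Compared with the FSDP path, the Megatron manager's accounting is minimal: each rank announces exactly one state-dict save and no other threads, and finalize_save_fn reports that single completion. Assuming the hypothetical CheckpointMonitor sketch given after the FSDP manager diff above (and placeholder values for the namespace and step, which are not from this commit), the per-rank handshake for one save reduces to the following sketch.

import ray

# Placeholders for illustration only; not values from this commit.
namespace = "trinity"
step = 100

# Uses the hypothetical CheckpointMonitor interface sketched earlier.
monitor = CheckpointMonitor.get_actor(namespace=namespace)

# Announce the work up front: one state-dict save thread, zero optimizer/extra/HF threads.
ray.get(monitor.register_checkpoint_save_count.remote(step, 1, 0))

# ... the asynchronous dist-checkpointing save runs here ...

# Report completion from finalize_save_fn; True flags the model state dict so the
# checkpoint can be picked up for a CHECKPOINT-method weight sync.
ray.get(monitor.notify_finished.remote(step, True))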
