 )
 from verl.utils.logger import log_with_rank

-from trinity.common.constants import SyncMethod
 from trinity.manager.synchronizer import Synchronizer
+from trinity.trainer.verl_trainer import CheckpointMonitor


 class FSDPCheckpointManager(OldFSDPCheckpointManager):
@@ -60,15 +60,12 @@ class FSDPCheckpointManager(OldFSDPCheckpointManager):
     This class is useful in distributed training scenarios where synchronization and non-blocking I/O are important.
     """

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, ray_namespace: str = "", **kwargs):
         super().__init__(*args, **kwargs)
-        config = kwargs.pop("config", None)
-        self.synchronizer_config = config
-        if config is not None:
-            # Retrieve the remote Synchronizer actor using the provided namespace
-            self.synchronizer = Synchronizer.get_actor(namespace=config.ray_namespace)
-        else:
-            self.synchronizer = None
+        self.synchronizer = Synchronizer.get_actor(namespace=ray_namespace)
+        self.checkpoint_monitor = CheckpointMonitor.get_actor(
+            namespace=ray_namespace,
+        )

         # Threads for asynchronous saving of different components
         self._model_state_dict_thread = None
@@ -77,21 +74,6 @@ def __init__(self, *args, **kwargs):
         self._save_model_thread = None
         self.previous_state_dict_step = None

-    def _notify_synchronizer_with_step_num(self, global_step):
-        """
-        Notifies the Synchronizer actor about the current training step number,
-        used when SyncMethod is CHECKPOINT.
-
-        Args:
-            global_step (int): The current global training step.
-        """
-        if getattr(self.synchronizer_config, "sync_method", None) == SyncMethod.CHECKPOINT:
-            ray.get(
-                self.synchronizer.set_model_state_dict_with_step_num.remote(
-                    global_step, self.world_size
-                )
-            )
-
     def _upload_state_dict(self, state_dict: Union[dict, None], global_step: int):
         """
         Internal method to upload a state dict to the Synchronizer actor.
@@ -131,14 +113,16 @@ def _save_model_state_dict():
                 rank=self.rank,
                 logger=logger,
             )
-            self._notify_synchronizer_with_step_num(global_step)
+            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step, True))

         self._model_state_dict_thread = threading.Thread(
             target=_save_model_state_dict,
         )
         self._model_state_dict_thread.start()

-    def _save_optimizer(self, local_path):
+        self.previous_state_dict_step = global_step
+
+    def _save_optimizer(self, local_path, global_step):
         optim_path = os.path.join(
             local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt"
         )
@@ -153,13 +137,14 @@ def _save_optimizer_state_dict():
                 rank=self.rank,
                 logger=logger,
             )
+            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step))

         self._optimizer_state_dict_thread = threading.Thread(
             target=_save_optimizer_state_dict,
         )
         self._optimizer_state_dict_thread.start()

-    def _save_extra_state(self, local_path):
+    def _save_extra_state(self, local_path, global_step):
         extra_path = os.path.join(
             local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt"
         )
@@ -180,6 +165,7 @@ def _save_extra_state_dict():
                 rank=self.rank,
                 logger=logger,
             )
+            ray.get(self.checkpoint_monitor.notify_finished.remote(global_step))

         self._extra_state_dict_thread = threading.Thread(
             target=_save_extra_state_dict,
@@ -193,11 +179,12 @@ def save_state_dict( # noqa: C901
         global_step: int = 0,
     ):
         if self.previous_state_dict_step is None:
+            # First sync in trainer.prepare
             self.previous_state_dict_step = global_step
             self._upload_state_dict(None, global_step)
             return
         elif self.previous_state_dict_step == global_step:
-            self._notify_synchronizer_with_step_num(global_step)
+            # No need to save for sync again
             return
         if local_path is None:
             return
@@ -213,8 +200,7 @@ def save_state_dict( # noqa: C901
             self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg
         ):
             self._save_model(local_path, global_step)
-
-        self.previous_state_dict_step = global_step
+        ray.get(self.checkpoint_monitor.register_state_dict_save_count.remote(global_step, 1))

     def save_checkpoint(  # noqa: C901
         self,
@@ -239,12 +225,14 @@ def save_checkpoint( # noqa: C901
             hdfs_path (str, optional): HDFS path for saving the checkpoint (not implemented here).
             global_step (int): Current training step.
             max_ckpt_to_keep (int, optional): Maximum number of checkpoints to keep locally.
-            model_state_dict_only (bool): Whether to only save the model state dict (no optimizer, etc.).
             save_as_hf (bool): Whether to force save the model in Hugging Face format.
         """
         if local_path is None:
             return

+        # record the previous global step
+        self.previous_global_step = global_step
+
         # remove previous local_path, only rank 0 should do this
         if (
             self.rank == 0
@@ -270,6 +258,9 @@ def save_checkpoint( # noqa: C901
             self.optimizer is not None
         ), "optimizer must be provided when checkpoint_contents.save includes ['optimizer']"

+        state_dict_thread_count = 0
+        other_thread_count = 0
+
         # every rank will save its own model and optim shard
         state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False)
         optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False)
@@ -279,16 +270,17 @@ def save_checkpoint( # noqa: C901
             self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg
         ):
             if self.should_save_model:
-                if self.previous_state_dict_step == global_step:
-                    self._notify_synchronizer_with_step_num(global_step)
-                else:
+                if self.previous_state_dict_step != global_step:
+                    state_dict_thread_count += 1
                     self._save_model(local_path, global_step)

             if self.should_save_optimizer:
-                self._save_optimizer(local_path)
+                other_thread_count += 1
+                self._save_optimizer(local_path, global_step)

             if self.should_save_extra:
-                self._save_extra_state(local_path)
+                other_thread_count += 1
+                self._save_extra_state(local_path, global_step)

             if self.rank == 0:
                 # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether
@@ -341,6 +333,7 @@ def save_checkpoint( # noqa: C901
             state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True)

             if self.rank == 0:
+                other_thread_count += 1
                 hf_local_path = os.path.join(local_path, "huggingface")
                 os.makedirs(hf_local_path, exist_ok=True)

@@ -386,19 +379,21 @@ def _save_model():
                         logger=logger,
                         log_only_rank_0=True,
                     )
+                    ray.get(self.checkpoint_monitor.notify_finished.remote(global_step))

                 self._save_model_thread = threading.Thread(
                     target=_save_model,
                 )
                 self._save_model_thread.start()
-                self.processing_class.save_pretrained(hf_local_path)

         # wait for rank0 to dump hf_model to local
         torch.distributed.barrier()

-        # record the previous global step
-        self.previous_global_step = global_step
-        self.previous_state_dict_step = global_step
+        ray.get(
+            self.checkpoint_monitor.register_checkpoint_save_count.remote(
+                global_step, state_dict_thread_count, other_thread_count
+            )
+        )
         self.previous_saved_paths.append(local_path)

     def wait_on_save_thread(self) -> None:
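
Note on the new `CheckpointMonitor` usage: the diff registers how many background save threads a step launches (`register_state_dict_save_count` in `save_state_dict`, `register_checkpoint_save_count` in `save_checkpoint`) and has each saving thread report completion via `notify_finished`. The actor itself lives in `trinity.trainer.verl_trainer` and is not shown in this diff. The sketch below is only an assumption reconstructed from the call sites above: the method names and argument lists come from the diff, but the class name `CheckpointMonitorSketch`, the dict-based bookkeeping, the `is_step_finished` helper, and the reading of the extra `True` in `notify_finished.remote(global_step, True)` as a model-state-dict flag are all hypothetical.

```python
import ray


@ray.remote
class CheckpointMonitorSketch:
    """Hypothetical stand-in for trinity.trainer.verl_trainer.CheckpointMonitor.

    Only mirrors the calls visible in this diff; the real actor also exposes
    get_actor(namespace=...) and may coordinate with the Synchronizer.
    """

    def __init__(self):
        self._expected = {}          # global_step -> number of notifications expected
        self._finished = {}          # global_step -> number of notifications received
        self._state_dict_done = {}   # global_step -> True once the model state dict is saved

    def register_state_dict_save_count(self, global_step: int, count: int) -> None:
        # save_state_dict() registers one model-state-dict thread per rank.
        self._expected[global_step] = self._expected.get(global_step, 0) + count

    def register_checkpoint_save_count(
        self, global_step: int, state_dict_count: int, other_count: int
    ) -> None:
        # save_checkpoint() registers the model thread (if newly started) plus the
        # optimizer, extra-state, and (on rank 0) HF-model threads.
        self._expected[global_step] = (
            self._expected.get(global_step, 0) + state_dict_count + other_count
        )

    def notify_finished(self, global_step: int, is_state_dict: bool = False) -> None:
        # Called from each background saving thread once its shard is on disk.
        self._finished[global_step] = self._finished.get(global_step, 0) + 1
        if is_state_dict:
            self._state_dict_done[global_step] = True

    def is_step_finished(self, global_step: int) -> bool:
        expected = self._expected.get(global_step)
        return expected is not None and self._finished.get(global_step, 0) >= expected
```

A caller could poll `is_step_finished` before treating a step's checkpoint as complete; whether the real monitor offers such a wait, or how it aggregates counts across ranks and the world size, is not visible in this diff.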