@@ -98,79 +98,62 @@ def upload_state_dict(self, global_step: int):
9898 state_dict = self .model .state_dict ()
9999 self ._upload_state_dict (state_dict , global_step )
100100
101- def _save_model (self , local_path , global_step ):
102- model_path = os .path .join (
103- local_path , f"model_world_size_{ self .world_size } _rank_{ self .rank } .pt"
101+ def _save_with_thread (
102+ self ,
103+ obj ,
104+ local_path : str ,
105+ prefix : str ,
106+ thread_name : str ,
107+ global_step : int ,
108+ is_state_dict : bool = False ,
109+ ):
110+ path = os .path .join (
111+ local_path , f"{ prefix } _world_size_{ self .world_size } _rank_{ self .rank } .pt"
104112 )
105- model_state_dict = self . model . state_dict ( )
106- if self . _model_state_dict_thread is not None :
107- self . _model_state_dict_thread .join ()
113+ thread = getattr ( self , thread_name )
114+ if thread is not None :
115+ thread .join ()
108116
109- def _save_model_state_dict ():
110- torch .save (model_state_dict , model_path )
117+ def _save ():
118+ torch .save (obj , path )
111119 log_with_rank (
112- f"Saved model to { os .path .abspath (model_path )} " ,
120+ f"Saved { prefix } to { os .path .abspath (path )} " ,
113121 rank = self .rank ,
114122 logger = logger ,
115123 )
116- ray .get (self .checkpoint_monitor .notify_finished .remote (global_step , True ))
124+ ray .get (self .checkpoint_monitor .notify_finished .remote (global_step , is_state_dict ))
125+
126+ thread = threading .Thread (
127+ target = _save ,
128+ )
129+ thread .start ()
130+ setattr (self , thread_name , thread )
117131
118- self ._model_state_dict_thread = threading .Thread (
119- target = _save_model_state_dict ,
132+ def _save_model (self , local_path , global_step ):
133+ model_state_dict = self .model .state_dict ()
134+ self ._save_with_thread (
135+ model_state_dict , local_path , "model" , "_model_state_dict_thread" , global_step , True
120136 )
121- self ._model_state_dict_thread .start ()
122137
123138 self .previous_state_dict_step = global_step
124139
125140 def _save_optimizer (self , local_path , global_step ):
126- optim_path = os .path .join (
127- local_path , f"optim_world_size_{ self .world_size } _rank_{ self .rank } .pt"
128- )
129141 optimizer_state_dict = self .optimizer .state_dict ()
130- if self ._optimizer_state_dict_thread is not None :
131- self ._optimizer_state_dict_thread .join ()
132-
133- def _save_optimizer_state_dict ():
134- torch .save (optimizer_state_dict , optim_path )
135- log_with_rank (
136- f"Saved optim to { os .path .abspath (optim_path )} " ,
137- rank = self .rank ,
138- logger = logger ,
139- )
140- ray .get (self .checkpoint_monitor .notify_finished .remote (global_step ))
141-
142- self ._optimizer_state_dict_thread = threading .Thread (
143- target = _save_optimizer_state_dict ,
142+ self ._save_with_thread (
143+ optimizer_state_dict , local_path , "optim" , "_optimizer_state_dict_thread" , global_step
144144 )
145- self ._optimizer_state_dict_thread .start ()
146145
147146 def _save_extra_state (self , local_path , global_step ):
148- extra_path = os .path .join (
149- local_path , f"extra_state_world_size_{ self .world_size } _rank_{ self .rank } .pt"
150- )
151147 lr_scheduler_state_dict = (
152148 self .lr_scheduler .state_dict () if self .lr_scheduler is not None else None
153149 )
154150 extra_state_dict = {
155151 "lr_scheduler" : lr_scheduler_state_dict ,
156152 "rng" : self .get_rng_state (),
157153 }
158- if self ._extra_state_dict_thread is not None :
159- self ._extra_state_dict_thread .join ()
160-
161- def _save_extra_state_dict ():
162- torch .save (extra_state_dict , extra_path )
163- log_with_rank (
164- f"Saved extra_state to { os .path .abspath (extra_path )} " ,
165- rank = self .rank ,
166- logger = logger ,
167- )
168- ray .get (self .checkpoint_monitor .notify_finished .remote (global_step ))
169-
170- self ._extra_state_dict_thread = threading .Thread (
171- target = _save_extra_state_dict ,
154+ self ._save_with_thread (
155+ extra_state_dict , local_path , "extra_state" , "_extra_state_dict_thread" , global_step
172156 )
173- self ._extra_state_dict_thread .start ()
174157
175158 def save_state_dict ( # noqa: C901
176159 self ,
@@ -200,7 +183,11 @@ def save_state_dict( # noqa: C901
200183 self .model , StateDictType .SHARDED_STATE_DICT , state_dict_cfg , optim_cfg
201184 ):
202185 self ._save_model (local_path , global_step )
203- ray .get (self .checkpoint_monitor .register_state_dict_save_count .remote (global_step , 1 ))
186+ ray .get (
187+ self .checkpoint_monitor .register_thread_count .remote (
188+ global_step , state_dict_thread_count = 1
189+ )
190+ )
204191
205192 def save_checkpoint ( # noqa: C901
206193 self ,
@@ -259,7 +246,7 @@ def save_checkpoint( # noqa: C901
259246 ), "optimizer must be provided when checkpoint_contents.save includes ['optimizer']"
260247
261248 state_dict_thread_count = 0
262- other_thread_count = 0
249+ checkpoint_thread_count = 0
263250
264251 # every rank will save its own model and optim shard
265252 state_dict_cfg = ShardedStateDictConfig (offload_to_cpu = True if is_cuda_available else False )
@@ -275,11 +262,11 @@ def save_checkpoint( # noqa: C901
275262 self ._save_model (local_path , global_step )
276263
277264 if self .should_save_optimizer :
278- other_thread_count += 1
265+ checkpoint_thread_count += 1
279266 self ._save_optimizer (local_path , global_step )
280267
281268 if self .should_save_extra :
282- other_thread_count += 1
269+ checkpoint_thread_count += 1
283270 self ._save_extra_state (local_path , global_step )
284271
285272 if self .rank == 0 :
@@ -333,7 +320,7 @@ def save_checkpoint( # noqa: C901
333320 state_dict = get_fsdp_full_state_dict (self .model , offload_to_cpu = True , rank0_only = True )
334321
335322 if self .rank == 0 :
336- other_thread_count += 1
323+ checkpoint_thread_count += 1
337324 hf_local_path = os .path .join (local_path , "huggingface" )
338325 os .makedirs (hf_local_path , exist_ok = True )
339326
@@ -390,8 +377,10 @@ def _save_model():
390377 torch .distributed .barrier ()
391378
392379 ray .get (
393- self .checkpoint_monitor .register_checkpoint_save_count .remote (
394- global_step , state_dict_thread_count , other_thread_count
380+ self .checkpoint_monitor .register_thread_count .remote (
381+ global_step ,
382+ state_dict_thread_count = state_dict_thread_count ,
383+ checkpoint_thread_count = checkpoint_thread_count ,
395384 )
396385 )
397386 self .previous_saved_paths .append (local_path )
0 commit comments