@@ -2616,8 +2616,8 @@ def save_model(
             self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
-                os.makedirs(signal_dir, exist_ok=True)
                 if self.is_in_train:
+                    os.makedirs(signal_dir, exist_ok=True)
                     global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
                     paddle.save(global_rank, os.path.join(signal_dir, f".model_weight.done.{global_rank}"))

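Net effect of this hunk: in `save_model`, the signal directory and the per-rank done marker are created only while training is in progress, instead of unconditionally whenever async save is configured. A minimal sketch of the resulting guard, with the Trainer's attributes passed in as plain arguments (the helper name is illustrative, not the Trainer's API):

```python
import os
import paddle

def emit_weight_done_signal(args, is_in_train, signal_dir):
    # Mirrors the post-change control flow in save_model: async-save signal
    # files are written only by a live training run.
    if args.unified_checkpoint and "async_save" in args.unified_checkpoint_config:
        if is_in_train:
            os.makedirs(signal_dir, exist_ok=True)
            global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
            # One marker per rank tells the async saver this rank's weights are done.
            paddle.save(global_rank, os.path.join(signal_dir, f".model_weight.done.{global_rank}"))
```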
@@ -2630,14 +2630,6 @@ def save_model(
         ):
             # For ckpt integrity
             paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))
-            if (
-                self.args.unified_checkpoint
-                and "async_save" in self.args.unified_checkpoint_config
-                and not self.is_in_train
-            ):
-                os.makedirs(signal_dir, exist_ok=True)
-                global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
-                paddle.save(self.state.global_step, os.path.join(signal_dir, f".model_weight.done.{global_rank}"))

     def _filter_moe_no_sync_optimizer_params(self):
         """
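This deletion is the counterpart of the hunk above: the special-cased branch that wrote a `.model_weight.done.{global_rank}` marker for saves issued outside training (`not self.is_in_train`) is dropped entirely, so non-training saves no longer touch `signal_dir` at all. A hedged usage sketch of the behavior after the commit (`trainer` is a hypothetical configured Trainer instance):

```python
# After this commit, a save issued once training has finished writes only the
# ".model_done" integrity marker; no ".model_weight.done.*" signal files and
# no signal directory are created.
trainer.save_model(output_dir="./final_model")
```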
@@ -2850,7 +2842,7 @@ def _save_checkpoint(self, model, metrics=None):
         need_to_rotate_checkpoints = self.args.should_save_model_state

         # Delete only by one process
-        need_to_rotate_checkpoints = need_to_rotate_checkpoints and self.args.local_rank == 0
+        need_to_rotate_checkpoints = need_to_rotate_checkpoints and self.args.local_rank in [0, -1]
         if need_to_rotate_checkpoints:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
             self._rotate_checkpoints(use_mtime=True, output_dir=run_signal_dir)
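The old guard required `local_rank == 0`, which skipped rotation entirely in single-process runs, assuming PaddleNLP follows the usual convention of reporting `local_rank == -1` when not distributed (the `... else -1` fallback for `global_rank` earlier in this diff points the same way). The membership test accepts both cases:

```python
# Illustration of the fixed guard: rank 0 of a distributed run and the -1
# sentinel of a single-process run both rotate; other ranks still skip.
for local_rank in (-1, 0, 1):
    print(local_rank, local_rank in [0, -1])  # -> -1 True, 0 True, 1 False
```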
@@ -2943,8 +2935,9 @@ def _save(
         if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
             if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]:
                 signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1])
-            os.makedirs(signal_dir, exist_ok=True)
-            logger.info(f"Saving model checkpoint finish signal to {signal_dir}")
+            if self.is_in_train:
+                os.makedirs(signal_dir, exist_ok=True)
+                logger.info(f"Saving model checkpoint finish signal to {signal_dir}")

         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
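The same `is_in_train` gating is applied inside `_save`: the per-checkpoint signal subdirectory is still resolved, but it is only materialized (and logged) during training. A sketch of that resolution as a standalone helper; `PREFIX_CHECKPOINT_DIR` is assumed here to be the `"checkpoint"` prefix from PaddleNLP's trainer utilities:

```python
import os

PREFIX_CHECKPOINT_DIR = "checkpoint"  # assumption: PaddleNLP's checkpoint dir prefix

def resolve_signal_dir(signal_dir, output_dir, is_in_train):
    # A save into ".../checkpoint-500" mirrors that leaf under the signal root.
    leaf = os.path.split(output_dir)[-1]
    if PREFIX_CHECKPOINT_DIR in leaf:
        signal_dir = os.path.join(signal_dir, leaf)
    # After this commit the directory exists on disk only for in-training saves.
    if is_in_train:
        os.makedirs(signal_dir, exist_ok=True)
    return signal_dir
```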
@@ -2980,7 +2973,7 @@ def _save(
             # backup and remove unified_checkpoint_config for not trine stage
             if not self.is_in_train:
                 self.args.unified_checkpoint_config = []
-
+                signal_dir = None
             self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir, signal_dir)

             # recover unified_checkpoint_config for not trine stage
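For the unified-checkpoint path, a non-training save now also nulls out `signal_dir` before delegating, making the "no signal files" contract explicit at the call site instead of implicit in the emptied config. A hedged sketch of the call-site shape (assuming `save_unified_checkpoint` tolerates `signal_dir=None`, which the indentation of the added line implies; `trainer` is a hypothetical Trainer saving after training has ended):

```python
# Hedged sketch of the call-site contract after this hunk.
if not trainer.is_in_train:
    trainer.args.unified_checkpoint_config = []  # drop async_save behavior
    signal_dir = None                            # and skip signal files entirely
trainer.unified_checkpoint_handler.save_unified_checkpoint(
    trainer.model, trainer.optimizer, output_dir, signal_dir
)
```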