
Commit be81e52

[Trainer] update trainer (#2378)

1 parent 7b47912
File tree

paddleformers/trainer/trainer.py
1 file changed: 6 additions & 13 deletions
@@ -2616,8 +2616,8 @@ def save_model(
             self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
-                os.makedirs(signal_dir, exist_ok=True)
                 if self.is_in_train:
+                    os.makedirs(signal_dir, exist_ok=True)
                     global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
                     paddle.save(global_rank, os.path.join(signal_dir, f".model_weight.done.{global_rank}"))

@@ -2630,14 +2630,6 @@ def save_model(
         ):
             # For ckpt integrity
             paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))
-            if (
-                self.args.unified_checkpoint
-                and "async_save" in self.args.unified_checkpoint_config
-                and not self.is_in_train
-            ):
-                os.makedirs(signal_dir, exist_ok=True)
-                global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
-                paddle.save(self.state.global_step, os.path.join(signal_dir, f".model_weight.done.{global_rank}"))

     def _filter_moe_no_sync_optimizer_params(self):
         """
@@ -2850,7 +2842,7 @@ def _save_checkpoint(self, model, metrics=None):
         need_to_rotate_checkpoints = self.args.should_save_model_state

         # Delete only by one process
-        need_to_rotate_checkpoints = need_to_rotate_checkpoints and self.args.local_rank == 0
+        need_to_rotate_checkpoints = need_to_rotate_checkpoints and self.args.local_rank in [0, -1]
         if need_to_rotate_checkpoints:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
             self._rotate_checkpoints(use_mtime=True, output_dir=run_signal_dir)
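
Note: local_rank is conventionally 0 for the first process of a distributed job but -1 when no distributed launcher is used, so the previous == 0 check silently disabled checkpoint rotation for single-process runs. A small illustration of the corrected predicate (function name assumed for the example):

def should_rotate(should_save_model_state, local_rank):
    # local_rank == 0  -> first process on the node in a distributed run
    # local_rank == -1 -> common sentinel for a non-distributed run
    # Either way, exactly one process per job performs the rotation.
    return should_save_model_state and local_rank in [0, -1]

assert should_rotate(True, 0)         # rank 0 still rotates
assert should_rotate(True, -1)        # single-process runs now rotate too
assert not should_rotate(True, 3)     # other ranks never rotate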
@@ -2943,8 +2935,9 @@ def _save(
         if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
             if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]:
                 signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1])
-            os.makedirs(signal_dir, exist_ok=True)
-            logger.info(f"Saving model checkpoint finish signal to {signal_dir}")
+            if self.is_in_train:
+                os.makedirs(signal_dir, exist_ok=True)
+                logger.info(f"Saving model checkpoint finish signal to {signal_dir}")

         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
@@ -2980,7 +2973,7 @@ def _save(
         # backup and remove unified_checkpoint_config for not trine stage
         if not self.is_in_train:
             self.args.unified_checkpoint_config = []
-
+            signal_dir = None
         self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir, signal_dir)

         # recover unified_checkpoint_config for not trine stage
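
Note: together with the previous hunk, _save now follows the same policy as save_model. During training the finish-signal directory is created and logged; outside training (not self.is_in_train) signal_dir is reset to None before save_unified_checkpoint is called, so the handler writes no signal files for evaluation- or predict-time saves. A hypothetical reduction of how a None signal dir short-circuits downstream (the real handler does much more, and the marker name below is invented for illustration):

import os

def save_unified_checkpoint_sketch(model, optimizer, output_dir, signal_dir=None):
    # Hypothetical reduction of unified_checkpoint_handler.save_unified_checkpoint.
    # ... persist model/optimizer state under output_dir (omitted) ...
    if signal_dir is not None:
        # Training-time async save: emit a finish marker for the coordinator.
        os.makedirs(signal_dir, exist_ok=True)
        open(os.path.join(signal_dir, ".save_done"), "w").close()  # invented marker name
    # signal_dir is None -> non-training save: no signal files are written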
