Skip to content

Commit fbe1fae

Browse files
authored
[Auto Parallel] Support save model file (#8927)
1 parent 0136312 commit fbe1fae

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

paddlenlp/trainer/auto_trainer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
MODEL_NAME = "model"
5050
OPTIMIZER_NAME = "optimizer"
5151
DIST_CKPT_PATH = "dist_ckpt"
52+
# Sub-directory of args.output_dir in which _save_model writes the per-rank
# ".pd_dist_model" files.
DIST_MODEL_PATH = "dist_model"
5253
# NOTE(review): "SVAE" looks like a transposition of "SAVE"; the identifier is
# left untouched here because code outside this view may reference it.
FREE_SVAE_LOAD_KEY_PATTERNS = ["learning_rate_", "gradient_merge_", "@GRAD@MERG", "eager_tmp"]
5354

5455

@@ -552,6 +553,18 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
552553
with _exec_mode_guard("dynamic"):
553554
super()._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, **kwargs)
554555

556+
def _save_model(self):
    """Write this rank's static "train" program to the output directory.

    Nothing happens unless ``self.args.to_static`` is truthy. Otherwise,
    under the ``"static"`` execution-mode guard, the object returned by
    ``self.model_wrapped.dist_main_program("train")`` is passed to
    ``paddle.save`` with the target path
    ``<args.output_dir>/<DIST_MODEL_PATH>/rank_<rank>.pd_dist_model``.
    The directory is created when missing, and a file already present at
    the target path is removed before saving.
    """
    if self.args.to_static:
        with _exec_mode_guard("static"):
            output_dir = f"{self.args.output_dir}/{DIST_MODEL_PATH}"
            os.makedirs(output_dir, exist_ok=True)
            logger.info(f"Saving model files into {output_dir}")

            # One file per process, keyed by the distributed rank.
            rank = paddle.distributed.get_rank()
            target = os.path.join(output_dir, f"rank_{rank}.pd_dist_model")

            # Clear out a previously written file before saving again.
            if os.path.exists(target):
                os.remove(target)

            train_program = self.model_wrapped.dist_main_program("train")
            paddle.save(train_program, target)
567+
555568
def _save_checkpoint(self, model, metrics=None):
556569

557570
# Save model checkpoint

0 commit comments

Comments
 (0)