Skip to content

Commit 96bc025

Browse files
authored
save checkpoint files with step and keep recent files (#1031)
This commit saves checkpoints to `save_ckpt-step` (e.g. `model.ckpt-100`) instead of `save_ckpt` (e.g. `model.ckpt`) and keeps the 5 most recent checkpoints (the default of `tf.train.Saver`). Both behaviors are handled by `tf.train.Saver` itself. To avoid breaking existing behavior, symlinks are then created at the step-less names (e.g. `model.ckpt.*`) pointing to the corresponding step-suffixed files (e.g. `model.ckpt-100.*`). (Normally this would be controlled by the `checkpoint` file, but deepmd-kit doesn't read that file.) This can fix #1023 because (1) the symlinks are made only after a checkpoint has already been saved, and (2) if something still goes wrong, a previous checkpoint can be used instead.
1 parent 8048e77 commit 96bc025

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

deepmd/train/trainer.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22
import logging
33
import os
4+
import glob
45
import time
56
import shutil
67
import google.protobuf.message
@@ -403,7 +404,7 @@ def _init_session(self):
403404
# Initializes or restore global variables
404405
init_op = tf.global_variables_initializer()
405406
if self.run_opt.is_chief:
406-
self.saver = tf.train.Saver()
407+
self.saver = tf.train.Saver(save_relative_paths=True)
407408
if self.run_opt.init_mode == 'init_from_scratch' :
408409
log.info("initialize model from scratch")
409410
run_sess(self.sess, init_op)
@@ -536,13 +537,24 @@ def train (self, train_data = None, valid_data=None) :
536537
train_time = 0
537538
if self.save_freq > 0 and cur_batch % self.save_freq == 0 and self.saver is not None:
538539
try:
539-
self.saver.save (self.sess, os.path.join(os.getcwd(), self.save_ckpt))
540+
ckpt_prefix = self.saver.save (self.sess, os.path.join(os.getcwd(), self.save_ckpt), global_step=cur_batch)
540541
except google.protobuf.message.DecodeError as e:
541542
raise GraphTooLargeError(
542543
"The graph size exceeds 2 GB, the hard limitation of protobuf."
543544
" Then a DecodeError was raised by protobuf. You should "
544545
"reduce the size of your model."
545546
) from e
547+
# make symlinks from prefix with step to that without step to break nothing
548+
# get all checkpoint files
549+
original_files = glob.glob(ckpt_prefix + ".*")
550+
for ori_ff in original_files:
551+
new_ff = self.save_ckpt + ori_ff[len(ckpt_prefix):]
552+
try:
553+
# remove old one
554+
os.remove(new_ff)
555+
except OSError:
556+
pass
557+
os.symlink(ori_ff, new_ff)
546558
log.info("saved checkpoint %s" % self.save_ckpt)
547559
if self.run_opt.is_chief:
548560
fp.close ()

0 commit comments

Comments
 (0)