Skip to content

Commit d96b442

Browse files
committed
rename checkpoint folder to checkpoint_serial
1 parent 9d98534 commit d96b442

File tree

1 file changed

+39
-27
lines changed

1 file changed

+39
-27
lines changed

python/paddle/fluid/io.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -455,37 +455,40 @@ def get_parameter_value_by_name(name, executor, program=None):
455455

456456

457457
SUCCESS_MARK_FILENAME = "_SUCCESS"
458+
CHECKPOINT_PREFIX = "checkpoint"
459+
CHECKPOINT_SEPARATOR = "_"
458460

459461

460462
def save_checkpoint(executor,
461-
dirname=None,
463+
checkpoint_dir=None,
462464
max_num_checkpoints=3,
463465
save_interval_secs=600,
464466
main_program=None):
465467
"""
466468
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
467469
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
468470
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
469-
The interval time between two save_checkpoint must great than or equal to save_interval_secs.
471+
The interval between two saved checkpoints must greater than save_interval_secs.
470472
471-
:param dirname
473+
:param executor
474+
:param checkpoint_dir
472475
:param max_num_checkpoints
473-
:param save_secs
476+
:param save_interval_secs
474477
:param main_program
475478
"""
476-
if dirname is None:
477-
dirname = os.getcwd()
479+
if checkpoint_dir is None:
480+
checkpoint_dir = os.getcwd()
478481

479-
if not os.path.isdir(dirname):
480-
os.makedirs(dirname)
482+
if not os.path.isdir(checkpoint_dir):
483+
os.makedirs(checkpoint_dir)
481484

482-
serial = _get_lastest_checkpoint_dir(dirname)
485+
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
483486
if serial >= 0 and not _interval_secs_exceed(
484-
os.path.join(dirname, str(serial)), save_interval_secs):
487+
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
485488
return
486489

487-
serial = serial + 1
488-
cur_dir = os.path.join(dirname, str(serial))
490+
serial += 1
491+
cur_dir = _get_serial_dir(serial, checkpoint_dir)
489492

490493
save_vars(
491494
executor,
@@ -495,27 +498,28 @@ def save_checkpoint(executor,
495498
predicate=_is_checkpoint_var,
496499
filename=None)
497500
_write_success(cur_dir)
498-
_lru_delete(dirname, max_num_checkpoints)
501+
_lru_delete(checkpoint_dir, max_num_checkpoints)
499502

500503

501-
def load_checkpoint(executor, dirname=None, main_program=None):
504+
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
502505
"""
503506
Load checkpoint from a directory by executor,
504-
it will find latest checkpoint file and load it auto.
507+
it will find the most recent saved checkpoint file and load it auto.
505508
506509
:param executor
507-
:param dirname
510+
:param checkpoint_dir
508511
:param main_program
509512
"""
510513

511-
if dirname is None:
512-
dirname = os.getcwd()
514+
if checkpoint_dir is None:
515+
checkpoint_dir = os.getcwd()
513516

514-
serial = _get_lastest_checkpoint_dir(dirname)
517+
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
515518

516519
if serial < 0:
517520
return
518-
cur_dir = os.path.join(dirname, str(serial))
521+
522+
cur_dir = _get_serial_dir(serial, checkpoint_dir)
519523

520524
load_vars(
521525
executor,
@@ -525,6 +529,11 @@ def load_checkpoint(executor, dirname=None, main_program=None):
525529
filename=None)
526530

527531

532+
def _get_serial_dir(serial, checkpoint_dir):
533+
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
534+
return os.path.join(checkpoint_dir, serial_folder)
535+
536+
528537
def _is_checkpoint_var(var):
529538
"""
530539
the checkpoint will not save or load all the variables.
@@ -577,7 +586,8 @@ def _write_success(dirname):
577586
"""
578587
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
579588
with open(success_file, 'a'):
580-
pass
589+
now = time.ctime()
590+
success_file.write(now)
581591

582592

583593
def _get_lastest_checkpoint_dir(checkpoint_dir):
@@ -593,18 +603,20 @@ def has_success(checkpoint_dir, cur_dir):
593603
"""
594604
is _SUCCESS in this dir
595605
"""
596-
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
597-
return -1
606+
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
598607

599608
try:
600-
int(cur_dir)
609+
int(serial)
601610
except ValueError:
602611
return -1
603612

604-
success_path = os.path.join(checkpoint_dir, cur_dir,
605-
SUCCESS_MARK_FILENAME)
613+
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
614+
return -1
615+
616+
success_path = os.path.join(
617+
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
606618
if os.path.isfile(success_path):
607-
return int(cur_dir)
619+
return int(serial)
608620

609621
if not os.path.isdir(checkpoint_dir):
610622
return -1

0 commit comments

Comments
 (0)