@@ -463,8 +463,11 @@ def save_checkpoint(executor,
463
463
save_interval_secs = 600 ,
464
464
main_program = None ):
465
465
"""
466
- Save Variables to Checkpoint Directory
467
-
466
+ Save the persistable LodTensor variables of main_program into a checkpoint directory.
467
+ Checkpoint directories are named by serial number from 0 to (n - 1); save_checkpoint uses an LRU strategy
468
+ to limit the number of checkpoint directories, keeping at most max_num_checkpoints of them.
469
+ The interval between two save_checkpoint calls must be greater than or equal to save_interval_secs.
470
+
468
471
:param dirname
469
472
:param max_num_checkpoints
470
473
:param save_interval_secs
@@ -489,18 +492,19 @@ def save_checkpoint(executor,
489
492
dirname = cur_dir ,
490
493
main_program = main_program ,
491
494
vars = None ,
492
- predicate = is_checkpoint_var ,
495
+ predicate = _is_checkpoint_var ,
493
496
filename = None )
494
497
_write_success (cur_dir )
495
498
_lru_delete (dirname , max_num_checkpoints )
496
499
497
500
498
501
def load_checkpoint (executor , dirname = None , main_program = None ):
499
502
"""
500
- Load Variables from Checkpint Dir
503
+ Load checkpoint from directory by executor,
504
+ it will find the latest checkpoint file and load it automatically.
501
505
502
- :param dirname
503
506
:param executor
507
+ :param dirname
504
508
:param main_program
505
509
"""
506
510
@@ -517,14 +521,16 @@ def load_checkpoint(executor, dirname=None, main_program=None):
517
521
executor ,
518
522
dirname = cur_dir ,
519
523
main_program = main_program ,
520
- predicate = is_checkpoint_var ,
524
+ predicate = _is_checkpoint_var ,
521
525
filename = None )
522
526
523
527
524
- def is_checkpoint_var (var ):
528
+ def _is_checkpoint_var (var ):
525
529
"""
526
- VarType will fliter out FEED_MINIBATCH FETCH_LIST RAW
527
- VarName will fliter out Gradient
530
+ checkpoint will not save or load all the variables.
531
+ variables whose type is FEED_MINIBATCH/FETCH_LIST/RAW, and variables whose name ends with @GRAD, are discarded.
532
+
533
+ :param var
528
534
"""
529
535
if var .desc .type () == core .VarDesc .VarType .FEED_MINIBATCH or \
530
536
var .desc .type () == core .VarDesc .VarType .FETCH_LIST or \
@@ -545,9 +551,6 @@ def _interval_secs_exceed(dirname, save_interval_secs):
545
551
546
552
547
553
def _lru_delete (dirname , max_num_checkpoints = 3 ):
548
- """
549
- retain checkpoint nums with max_num_checkpoints
550
- """
551
554
dirs = os .listdir (dirname )
552
555
serials = []
553
556
for serial in dirs :
@@ -568,7 +571,9 @@ def _lru_delete(dirname, max_num_checkpoints=3):
568
571
569
572
def _write_success (dirname ):
570
573
"""
571
- write _SUCCESS to checkpoint dir
574
+ write an empty _SUCCESS file to the checkpoint dir, indicating that this checkpoint is correct.
575
+
576
+ :param dirname
572
577
"""
573
578
success_file = os .path .join (dirname , SUCCESS_MARK_FILENAME )
574
579
with open (success_file , 'a' ):
@@ -577,7 +582,9 @@ def _write_success(dirname):
577
582
578
583
def _get_lastest_checkpoint_dir (checkpoint_dir ):
579
584
"""
580
- get the biggest number in checkpoint_dir, which has _SUCCESS
585
+ get the latest file in the checkpoint directory; the _SUCCESS file must exist in that directory
586
+
587
+ :param checkpoint_dir
581
588
"""
582
589
if not checkpoint_dir .strip ():
583
590
return - 1
0 commit comments