23
23
__all__ = [
24
24
'save_vars' , 'save_params' , 'save_persistables' , 'load_vars' , 'load_params' ,
25
25
'load_persistables' , 'save_inference_model' , 'load_inference_model' ,
26
- 'get_inference_program' , 'save_checkpoint' , 'restore_checkpoint '
26
+ 'get_inference_program' , 'save_checkpoint' , 'load_checkpoint '
27
27
]
28
28
29
29
@@ -466,7 +466,7 @@ def save_checkpoint(executor,
466
466
Save Variables to Checkpoint Directory
467
467
468
468
:param dirname
469
- :param keep_max
469
+ :param max_num_checkpoints
470
470
:param save_secs
471
471
:param main_program
472
472
"""
@@ -495,7 +495,7 @@ def save_checkpoint(executor,
495
495
_lru_delete (dirname , max_num_checkpoints )
496
496
497
497
498
- def restore_checkpoint (executor , dirname = None , main_program = None ):
498
+ def load_checkpoint (executor , dirname = None , main_program = None ):
499
499
"""
500
500
Load Variables from Checkpint Dir
501
501
@@ -544,9 +544,9 @@ def _interval_secs_exceed(dirname, save_interval_secs):
544
544
return True
545
545
546
546
547
- def _lru_delete (dirname , keep_max = 3 ):
547
+ def _lru_delete (dirname , max_num_checkpoints = 3 ):
548
548
"""
549
- retain checkpoint nums with keep_max
549
+ retain checkpoint nums with max_num_checkpoints
550
550
"""
551
551
dirs = os .listdir (dirname )
552
552
serials = []
@@ -556,11 +556,11 @@ def _lru_delete(dirname, keep_max=3):
556
556
except ValueError :
557
557
continue
558
558
559
- if len (serials ) <= keep_max :
559
+ if len (serials ) <= max_num_checkpoints :
560
560
return
561
561
562
562
serials .sort (reverse = True )
563
- serials = serials [keep_max :]
563
+ serials = serials [max_num_checkpoints :]
564
564
for serial in serials :
565
565
cur_dir = os .path .join (dirname , str (serial ))
566
566
shutil .rmtree (cur_dir )
0 commit comments