Skip to content

Commit 27b7175

Browse files
committed
update python annotation
1 parent e901de6 commit 27b7175

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

python/paddle/fluid/io.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -463,8 +463,11 @@ def save_checkpoint(executor,
463463
save_interval_secs=600,
464464
main_program=None):
465465
"""
466-
Save Variables to Checkpoint Directory
467-
466+
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
467+
directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
468+
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
469+
The interval time between two save_checkpoint must great than or equal to save_interval_secs.
470+
468471
:param dirname
469472
:param max_num_checkpoints
470473
:param save_secs
@@ -489,18 +492,19 @@ def save_checkpoint(executor,
489492
dirname=cur_dir,
490493
main_program=main_program,
491494
vars=None,
492-
predicate=is_checkpoint_var,
495+
predicate=_is_checkpoint_var,
493496
filename=None)
494497
_write_success(cur_dir)
495498
_lru_delete(dirname, max_num_checkpoints)
496499

497500

498501
def load_checkpoint(executor, dirname=None, main_program=None):
499502
"""
500-
Load Variables from Checkpint Dir
503+
Load checkpoint from directory by executor,
504+
it will find lastest checkpoint file and load it auto.
501505
502-
:param dirname
503506
:param executor
507+
:param dirname
504508
:param main_program
505509
"""
506510

@@ -517,14 +521,16 @@ def load_checkpoint(executor, dirname=None, main_program=None):
517521
executor,
518522
dirname=cur_dir,
519523
main_program=main_program,
520-
predicate=is_checkpoint_var,
524+
predicate=_is_checkpoint_var,
521525
filename=None)
522526

523527

524-
def is_checkpoint_var(var):
528+
def _is_checkpoint_var(var):
525529
"""
526-
VarType will fliter out FEED_MINIBATCH FETCH_LIST RAW
527-
VarName will fliter out Gradient
530+
checkpoint will not save or load all the variables.
531+
var type is FEED_MINIBATCH/FETCH_LIST/RAW and var name is end with @GRAD are discarded.
532+
533+
:param var
528534
"""
529535
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
530536
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
@@ -545,9 +551,6 @@ def _interval_secs_exceed(dirname, save_interval_secs):
545551

546552

547553
def _lru_delete(dirname, max_num_checkpoints=3):
548-
"""
549-
retain checkpoint nums with max_num_checkpoints
550-
"""
551554
dirs = os.listdir(dirname)
552555
serials = []
553556
for serial in dirs:
@@ -568,7 +571,9 @@ def _lru_delete(dirname, max_num_checkpoints=3):
568571

569572
def _write_success(dirname):
570573
"""
571-
write _SUCCESS to checkpoint dir
574+
write an empty _SUCCESS file to checkpoint dir, indicate this checkpoint is correct.
575+
576+
:param dirname
572577
"""
573578
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
574579
with open(success_file, 'a'):
@@ -577,7 +582,9 @@ def _write_success(dirname):
577582

578583
def _get_lastest_checkpoint_dir(checkpoint_dir):
579584
"""
580-
get the biggest number in checkpoint_dir, which has _SUCCESS
585+
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
586+
587+
:param checkpoint_dir
581588
"""
582589
if not checkpoint_dir.strip():
583590
return -1

0 commit comments

Comments
 (0)