Skip to content

Commit 2412dee

Browse files
committed
code optimized
1 parent 06aa23b commit 2412dee

File tree

1 file changed

+33
-22
lines changed

1 file changed

+33
-22
lines changed

python/paddle/fluid/io.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -454,38 +454,36 @@ def get_parameter_value_by_name(name, executor, program=None):
454454
return get_parameter_value(var, executor)
455455

456456

457-
SUCCESS = "_SUCCESS"
458-
BEGIN_SECS = None
457+
SUCCESS_MARK_FILENAME = "_SUCCESS"
459458

460459

461460
def save_checkpoint(executor,
462-
dirname,
463-
keep_max=3,
464-
save_secs=600,
461+
dirname=None,
462+
max_num_checkpoints=3,
463+
save_interval_secs=600,
465464
main_program=None):
466465
"""
467-
Save Variables to Checkpint Dir
466+
Save Variables to Checkpoint Directory
468467
469468
:param dirname
470469
:param keep_max
471470
:param save_secs
472471
:param main_program
473472
"""
474473
if dirname is None:
475-
raise Exception("save checkpoint dir can not be none")
474+
dirname = os.getcwd()
476475

477476
if not os.path.isdir(dirname):
478477
os.makedirs(dirname)
479478

480-
global BEGIN_SECS
481-
if BEGIN_SECS is not None:
482-
if time.time() - BEGIN_SECS < save_secs:
483-
return
484-
BEGIN_SECS = time.time()
479+
serial = _get_lastest_checkpoint_dir(dirname)
480+
if serial >= 0 and not _interval_secs_exceed(
481+
os.path.join(dirname, str(serial)), save_interval_secs):
482+
return
485483

486-
serial = _get_lastest_checkpoint_dir(dirname) + 1
484+
serial = serial + 1
487485
cur_dir = os.path.join(dirname, str(serial))
488-
# save_persistables(executor, cur_dir, main_program)
486+
489487
save_vars(
490488
executor,
491489
dirname=cur_dir,
@@ -494,26 +492,27 @@ def save_checkpoint(executor,
494492
predicate=is_checkpoint_var,
495493
filename=None)
496494
_write_success(cur_dir)
497-
_lru_delete(dirname, keep_max)
495+
_lru_delete(dirname, max_num_checkpoints)
498496

499497

500-
def restore_checkpoint(dirname, executor, main_program=None):
498+
def restore_checkpoint(executor, dirname=None, main_program=None):
501499
"""
502500
Load Variables from Checkpoint Dir
503501
504502
:param dirname
505503
:param executor
506504
:param main_program
507505
"""
508-
if dirname is None and os.path.isdir(dirname):
509-
raise Exception("restore checkpoint can not load variables from %s" %
510-
dirname)
506+
507+
if dirname is None:
508+
dirname = os.getcwd()
509+
511510
serial = _get_lastest_checkpoint_dir(dirname)
512511

513512
if serial < 0:
514513
return
515514
cur_dir = os.path.join(dirname, str(serial))
516-
# load_persistables(executor, cur_dir, main_program)
515+
517516
load_vars(
518517
executor,
519518
dirname=cur_dir,
@@ -523,6 +522,10 @@ def restore_checkpoint(dirname, executor, main_program=None):
523522

524523

525524
def is_checkpoint_var(var):
525+
"""
526+
VarType will filter out FEED_MINIBATCH FETCH_LIST RAW
527+
VarName will filter out Gradient
528+
"""
526529
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
527530
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
528531
var.desc.type() == core.VarDesc.VarType.RAW:
@@ -534,6 +537,13 @@ def is_checkpoint_var(var):
534537
return var.persistable
535538

536539

540+
def _interval_secs_exceed(dirname, save_interval_secs):
541+
dir_time = os.path.getmtime(dirname)
542+
if save_interval_secs > (time.time() - dir_time):
543+
return False
544+
return True
545+
546+
537547
def _lru_delete(dirname, keep_max=3):
538548
"""
539549
retain checkpoint nums with keep_max
@@ -560,7 +570,7 @@ def _write_success(dirname):
560570
"""
561571
write _SUCCESS to checkpoint dir
562572
"""
563-
success_file = os.path.join(dirname, SUCCESS)
573+
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
564574
with open(success_file, 'a'):
565575
pass
566576

@@ -584,7 +594,8 @@ def has_success(checkpoint_dir, cur_dir):
584594
except ValueError:
585595
return -1
586596

587-
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
597+
success_path = os.path.join(checkpoint_dir, cur_dir,
598+
SUCCESS_MARK_FILENAME)
588599
if os.path.isfile(success_path):
589600
return int(cur_dir)
590601

0 commit comments

Comments
 (0)