@@ -454,38 +454,36 @@ def get_parameter_value_by_name(name, executor, program=None):
454
454
return get_parameter_value (var , executor )
455
455
456
456
457
- SUCCESS = "_SUCCESS"
458
- BEGIN_SECS = None
457
+ SUCCESS_MARK_FILENAME = "_SUCCESS"
459
458
460
459
461
460
def save_checkpoint (executor ,
462
- dirname ,
463
- keep_max = 3 ,
464
- save_secs = 600 ,
461
+ dirname = None ,
462
+ max_num_checkpoints = 3 ,
463
+ save_interval_secs = 600 ,
465
464
main_program = None ):
466
465
"""
467
- Save Variables to Checkpint Dir
466
+ Save Variables to Checkpoint Directory
468
467
469
468
:param dirname
470
469
:param keep_max
471
470
:param save_secs
472
471
:param main_program
473
472
"""
474
473
if dirname is None :
475
- raise Exception ( "save checkpoint dir can not be none" )
474
+ dirname = os . getcwd ( )
476
475
477
476
if not os .path .isdir (dirname ):
478
477
os .makedirs (dirname )
479
478
480
- global BEGIN_SECS
481
- if BEGIN_SECS is not None :
482
- if time .time () - BEGIN_SECS < save_secs :
483
- return
484
- BEGIN_SECS = time .time ()
479
+ serial = _get_lastest_checkpoint_dir (dirname )
480
+ if serial >= 0 and not _interval_secs_exceed (
481
+ os .path .join (dirname , str (serial )), save_interval_secs ):
482
+ return
485
483
486
- serial = _get_lastest_checkpoint_dir ( dirname ) + 1
484
+ serial = serial + 1
487
485
cur_dir = os .path .join (dirname , str (serial ))
488
- # save_persistables(executor, cur_dir, main_program)
486
+
489
487
save_vars (
490
488
executor ,
491
489
dirname = cur_dir ,
@@ -494,26 +492,27 @@ def save_checkpoint(executor,
494
492
predicate = is_checkpoint_var ,
495
493
filename = None )
496
494
_write_success (cur_dir )
497
- _lru_delete (dirname , keep_max )
495
+ _lru_delete (dirname , max_num_checkpoints )
498
496
499
497
500
- def restore_checkpoint (dirname , executor , main_program = None ):
498
+ def restore_checkpoint (executor , dirname = None , main_program = None ):
501
499
"""
502
500
Load Variables from Checkpint Dir
503
501
504
502
:param dirname
505
503
:param executor
506
504
:param main_program
507
505
"""
508
- if dirname is None and os .path .isdir (dirname ):
509
- raise Exception ("restore checkpoint can not load variables from %s" %
510
- dirname )
506
+
507
+ if dirname is None :
508
+ dirname = os .getcwd ()
509
+
511
510
serial = _get_lastest_checkpoint_dir (dirname )
512
511
513
512
if serial < 0 :
514
513
return
515
514
cur_dir = os .path .join (dirname , str (serial ))
516
- # load_persistables(executor, cur_dir, main_program)
515
+
517
516
load_vars (
518
517
executor ,
519
518
dirname = cur_dir ,
@@ -523,6 +522,10 @@ def restore_checkpoint(dirname, executor, main_program=None):
523
522
524
523
525
524
def is_checkpoint_var (var ):
525
+ """
526
+ VarType will fliter out FEED_MINIBATCH FETCH_LIST RAW
527
+ VarName will fliter out Gradient
528
+ """
526
529
if var .desc .type () == core .VarDesc .VarType .FEED_MINIBATCH or \
527
530
var .desc .type () == core .VarDesc .VarType .FETCH_LIST or \
528
531
var .desc .type () == core .VarDesc .VarType .RAW :
@@ -534,6 +537,13 @@ def is_checkpoint_var(var):
534
537
return var .persistable
535
538
536
539
540
+ def _interval_secs_exceed (dirname , save_interval_secs ):
541
+ dir_time = os .path .getmtime (dirname )
542
+ if save_interval_secs > (time .time () - dir_time ):
543
+ return False
544
+ return True
545
+
546
+
537
547
def _lru_delete (dirname , keep_max = 3 ):
538
548
"""
539
549
retain checkpoint nums with keep_max
@@ -560,7 +570,7 @@ def _write_success(dirname):
560
570
"""
561
571
write _SUCCESS to checkpoint dir
562
572
"""
563
- success_file = os .path .join (dirname , SUCCESS )
573
+ success_file = os .path .join (dirname , SUCCESS_MARK_FILENAME )
564
574
with open (success_file , 'a' ):
565
575
pass
566
576
@@ -584,7 +594,8 @@ def has_success(checkpoint_dir, cur_dir):
584
594
except ValueError :
585
595
return - 1
586
596
587
- success_path = os .path .join (checkpoint_dir , cur_dir , SUCCESS )
597
+ success_path = os .path .join (checkpoint_dir , cur_dir ,
598
+ SUCCESS_MARK_FILENAME )
588
599
if os .path .isfile (success_path ):
589
600
return int (cur_dir )
590
601
0 commit comments