@@ -455,37 +455,40 @@ def get_parameter_value_by_name(name, executor, program=None):
455
455
456
456
457
457
SUCCESS_MARK_FILENAME = "_SUCCESS"
458
+ CHECKPOINT_PREFIX = "checkpoint"
459
+ CHECKPOINT_SEPARATOR = "_"
458
460
459
461
460
462
def save_checkpoint (executor ,
461
- dirname = None ,
463
+ checkpoint_dir = None ,
462
464
max_num_checkpoints = 3 ,
463
465
save_interval_secs = 600 ,
464
466
main_program = None ):
465
467
"""
466
468
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
467
469
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
468
470
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
469
- The interval time between two save_checkpoint must great than or equal to save_interval_secs.
471
+ The interval between two saved checkpoints must greater than save_interval_secs.
470
472
471
- :param dirname
473
+ :param executor
474
+ :param checkpoint_dir
472
475
:param max_num_checkpoints
473
- :param save_secs
476
+ :param save_interval_secs
474
477
:param main_program
475
478
"""
476
- if dirname is None :
477
- dirname = os .getcwd ()
479
+ if checkpoint_dir is None :
480
+ checkpoint_dir = os .getcwd ()
478
481
479
- if not os .path .isdir (dirname ):
480
- os .makedirs (dirname )
482
+ if not os .path .isdir (checkpoint_dir ):
483
+ os .makedirs (checkpoint_dir )
481
484
482
- serial = _get_lastest_checkpoint_dir (dirname )
485
+ serial = _get_lastest_checkpoint_dir (checkpoint_dir )
483
486
if serial >= 0 and not _interval_secs_exceed (
484
- os . path . join ( dirname , str ( serial ) ), save_interval_secs ):
487
+ _get_serial_dir ( serial , checkpoint_dir ), save_interval_secs ):
485
488
return
486
489
487
- serial = serial + 1
488
- cur_dir = os . path . join ( dirname , str ( serial ) )
490
+ serial += 1
491
+ cur_dir = _get_serial_dir ( serial , checkpoint_dir )
489
492
490
493
save_vars (
491
494
executor ,
@@ -495,27 +498,28 @@ def save_checkpoint(executor,
495
498
predicate = _is_checkpoint_var ,
496
499
filename = None )
497
500
_write_success (cur_dir )
498
- _lru_delete (dirname , max_num_checkpoints )
501
+ _lru_delete (checkpoint_dir , max_num_checkpoints )
499
502
500
503
501
- def load_checkpoint (executor , dirname = None , main_program = None ):
504
+ def load_checkpoint (executor , checkpoint_dir = None , main_program = None ):
502
505
"""
503
506
Load checkpoint from a directory by executor,
504
- it will find latest checkpoint file and load it auto.
507
+ it will find the most recent saved checkpoint file and load it auto.
505
508
506
509
:param executor
507
- :param dirname
510
+ :param checkpoint_dir
508
511
:param main_program
509
512
"""
510
513
511
- if dirname is None :
512
- dirname = os .getcwd ()
514
+ if checkpoint_dir is None :
515
+ checkpoint_dir = os .getcwd ()
513
516
514
- serial = _get_lastest_checkpoint_dir (dirname )
517
+ serial = _get_lastest_checkpoint_dir (checkpoint_dir )
515
518
516
519
if serial < 0 :
517
520
return
518
- cur_dir = os .path .join (dirname , str (serial ))
521
+
522
+ cur_dir = _get_serial_dir (serial , checkpoint_dir )
519
523
520
524
load_vars (
521
525
executor ,
@@ -525,6 +529,11 @@ def load_checkpoint(executor, dirname=None, main_program=None):
525
529
filename = None )
526
530
527
531
532
+ def _get_serial_dir (serial , checkpoint_dir ):
533
+ serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str (serial )
534
+ return os .path .join (checkpoint_dir , serial_folder )
535
+
536
+
528
537
def _is_checkpoint_var (var ):
529
538
"""
530
539
the checkpoint will not save or load all the variables.
@@ -577,7 +586,8 @@ def _write_success(dirname):
577
586
"""
578
587
success_file = os .path .join (dirname , SUCCESS_MARK_FILENAME )
579
588
with open (success_file , 'a' ):
580
- pass
589
+ now = time .ctime ()
590
+ success_file .write (now )
581
591
582
592
583
593
def _get_lastest_checkpoint_dir (checkpoint_dir ):
@@ -593,18 +603,20 @@ def has_success(checkpoint_dir, cur_dir):
593
603
"""
594
604
is _SUCCESS in this dir
595
605
"""
596
- if not os .path .isdir (os .path .join (checkpoint_dir , cur_dir )):
597
- return - 1
606
+ _ , serial = cur_dir .split (CHECKPOINT_SEPARATOR )
598
607
599
608
try :
600
- int (cur_dir )
609
+ int (serial )
601
610
except ValueError :
602
611
return - 1
603
612
604
- success_path = os .path .join (checkpoint_dir , cur_dir ,
605
- SUCCESS_MARK_FILENAME )
613
+ if not os .path .isdir (os .path .join (checkpoint_dir , cur_dir )):
614
+ return - 1
615
+
616
+ success_path = os .path .join (
617
+ _get_serial_dir (serial , checkpoint_dir ), SUCCESS_MARK_FILENAME )
606
618
if os .path .isfile (success_path ):
607
- return int (cur_dir )
619
+ return int (serial )
608
620
609
621
if not os .path .isdir (checkpoint_dir ):
610
622
return - 1
0 commit comments