24
24
'save_vars' , 'save_params' , 'save_persistables' , 'load_vars' , 'load_params' ,
25
25
'load_persistables' , 'save_inference_model' , 'load_inference_model' ,
26
26
'get_inference_program' , 'save_checkpoint' , 'load_checkpoint' ,
27
- 'clean_checkpoint'
27
+ 'clean_checkpoint' , 'load_persist_vars_without_grad' ,
28
+ 'save_persist_vars_without_grad' , 'get_latest_checkpoint_serial'
28
29
]
29
30
30
31
@@ -457,95 +458,161 @@ def get_parameter_value_by_name(name, executor, program=None):
457
458
458
459
SUCCESS_MARK_FILENAME = "_SUCCESS"
459
460
CHECKPOINT_PREFIX = "checkpoint"
461
+ MODEL_DIR = "__model__"
462
+ TRAINER_PREFIX = "trainer"
460
463
CHECKPOINT_SEPARATOR = "_"
461
464
462
465
463
466
def save_checkpoint (executor ,
464
- checkpoint_dir = None ,
465
- max_num_checkpoints = 3 ,
466
- save_interval_secs = 600 ,
467
- main_program = None ):
467
+ checkpoint_dir ,
468
+ trainer_id ,
469
+ trainer_args = None ,
470
+ main_program = None ,
471
+ max_num_checkpoints = 3 ):
468
472
"""
469
473
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
470
474
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
471
475
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
472
476
The interval between two saved checkpoints must greater than save_interval_secs.
473
477
474
- :param executor
475
- :param checkpoint_dir
476
- :param max_num_checkpoints
477
- :param save_interval_secs
478
- :param main_program
478
+ :param executor executor for save the value
479
+ :param checkpoint_dir the checkpoint directory
480
+ :param trainer_id currect trainer id, if id is equal to 0, the trainer is chief
481
+ :param main_program will save all variables in program
482
+ :param max_num_checkpoints will keep numbers of checkpoint serials not bigger than max_num_checkpoints
479
483
"""
480
484
if checkpoint_dir is None :
481
- checkpoint_dir = os .getcwd ()
485
+ raise ValueError ("'checkpoint_dir' should not be None" )
486
+
487
+ if trainer_args :
488
+ assert isinstance (trainer_args , dict )
482
489
483
490
if not os .path .isdir (checkpoint_dir ):
484
491
os .makedirs (checkpoint_dir )
485
492
486
- serial = _get_lastest_checkpoint_dir (checkpoint_dir )
487
- if serial >= 0 and not _interval_secs_exceed (
488
- _get_serial_dir (serial , checkpoint_dir ), save_interval_secs ):
489
- return
493
+ serial = get_latest_checkpoint_serial (checkpoint_dir ) + 1
494
+ cur_dir = _get_serial_dir (checkpoint_dir , serial )
490
495
491
- serial += 1
492
- cur_dir = _get_serial_dir (serial , checkpoint_dir )
496
+ save_trainer_args (cur_dir , trainer_id , trainer_args )
493
497
494
- save_vars (
495
- executor ,
496
- dirname = cur_dir ,
497
- main_program = main_program ,
498
- vars = None ,
499
- predicate = _is_checkpoint_var ,
500
- filename = None )
501
- _write_success (cur_dir )
502
- _lru_delete (checkpoint_dir , max_num_checkpoints )
498
+ if trainer_id == 0 :
499
+ save_persist_vars_without_grad (executor , cur_dir , main_program )
500
+
501
+ _scroll_delete (checkpoint_dir , max_num_checkpoints )
503
502
504
503
505
- def load_checkpoint (executor , checkpoint_dir = None , main_program = None ):
504
+ def load_checkpoint (executor , checkpoint_dir , serial , main_program ):
506
505
"""
507
506
Load checkpoint from a directory by executor,
508
507
it will find the most recent saved checkpoint file and load it auto.
509
508
510
- :param executor
511
- :param checkpoint_dir
512
- :param main_program
509
+ :param executor executor for load the value
510
+ :param checkpoint_dir the checkpoint directory
511
+ :param serial the serial folder in checkpoint directory will be load
512
+ :param main_program will load all variables in program
513
513
"""
514
514
515
515
if checkpoint_dir is None :
516
- checkpoint_dir = os . getcwd ( )
516
+ raise ValueError ( "' checkpoint_dir' should not be None" )
517
517
518
- serial = _get_lastest_checkpoint_dir (checkpoint_dir )
518
+ if serial is None or serial < 0 :
519
+ raise ValueError ("'serial' should not be None or <0 " )
519
520
520
- if serial < 0 :
521
- return
521
+ if main_program is None :
522
+ raise ValueError ( 'main_program should not be None.' )
522
523
523
- cur_dir = _get_serial_dir (serial , checkpoint_dir )
524
-
525
- load_vars (
526
- executor ,
527
- dirname = cur_dir ,
528
- main_program = main_program ,
529
- predicate = _is_checkpoint_var ,
530
- filename = None )
524
+ cur_dir = _get_serial_dir (checkpoint_dir , serial )
525
+ load_persist_vars_without_grad (executor , cur_dir , main_program , True )
531
526
532
527
533
528
def clean_checkpoint (checkpoint_dir , delete_dir = False ):
534
529
"""
535
530
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
536
531
delete_dir only works when the directory is empty, otherwise, OSError is raised.
532
+
533
+ :param checkpoint_dir
534
+ :param delete_dir
537
535
"""
536
+
538
537
if checkpoint_dir is None :
539
- checkpoint_dir = os . getcwd ( )
540
- _lru_delete (checkpoint_dir , max_num_checkpoints = 0 )
538
+ raise ValueError ( "' checkpoint_dir' should not be None" )
539
+ _scroll_delete (checkpoint_dir , max_num_checkpoints = 0 )
541
540
542
541
if delete_dir and not os .listdir (checkpoint_dir ):
543
542
os .rmdir (checkpoint_dir )
544
543
545
544
546
- def _get_serial_dir (serial , checkpoint_dir ):
547
- serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str (serial )
548
- return os .path .join (checkpoint_dir , serial_folder )
545
+ def load_persist_vars_without_grad (executor ,
546
+ dirname ,
547
+ program ,
548
+ has_model_dir = False ):
549
+ """
550
+ load_persist_vars_without_grad will load variables from a directory by an executor,
551
+ the variable named end with "@GRAD" will not be loaded.
552
+
553
+ :param executor executor for load the value
554
+ :param dirname the checkpoint directory
555
+ :param program will load all variables in program
556
+ :param has_model_dir if has_model_dir is True, will load variables from sub directory named __model__
557
+ """
558
+
559
+ if has_model_dir :
560
+ dirname = _get_model_dir (dirname )
561
+
562
+ load_vars (
563
+ executor ,
564
+ dirname = dirname ,
565
+ main_program = program ,
566
+ predicate = _is_checkpoint_var ,
567
+ filename = None )
568
+
569
+
570
+ def save_persist_vars_without_grad (executor , dirname , program ):
571
+ """
572
+ save_persist_vars_without_grad will save variables to a directory by an executor,
573
+ the variable named end with "@GRAD" will not be saved.
574
+
575
+ :param executor executor for load the value
576
+ :param dirname the checkpoint directory
577
+ :param program will load all variables in program
578
+ """
579
+ cur_dir = _get_model_dir (dirname )
580
+ save_vars (
581
+ executor ,
582
+ dirname = cur_dir ,
583
+ main_program = program ,
584
+ vars = None ,
585
+ predicate = _is_checkpoint_var ,
586
+ filename = None )
587
+ _write_success (cur_dir )
588
+
589
+
590
+ def save_trainer_args (dirname , trainer_id , trainer_args ):
591
+ assert isinstance (trainer_args , dict )
592
+
593
+ cur_dir = _get_trainer_dir (dirname , trainer_id )
594
+
595
+ for name , value in trainer_args .iteritems ():
596
+ args_file = os .path .join (cur_dir , name )
597
+ with open (args_file , 'w' ) as f :
598
+ f .write (str (value ))
599
+ _write_success (cur_dir )
600
+
601
+
602
+ def load_trainer_args (checkpoint_dir , serial , trainer_id , trainer_args ):
603
+ assert isinstance (trainer_args , list )
604
+
605
+ cur_dir = _get_serial_dir (checkpoint_dir , serial )
606
+ cur_dir = _get_trainer_dir (cur_dir , trainer_id )
607
+
608
+ ret_values = []
609
+
610
+ for arg in trainer_args :
611
+ cur_file = os .path .join (cur_dir , arg )
612
+ with open (cur_file , 'r' ) as f :
613
+ contents = f .read ()
614
+ ret_values .append (contents .strip ())
615
+ return ret_values
549
616
550
617
551
618
def _is_checkpoint_var (var ):
@@ -559,36 +626,74 @@ def _is_checkpoint_var(var):
559
626
var .desc .type () == core .VarDesc .VarType .FETCH_LIST or \
560
627
var .desc .type () == core .VarDesc .VarType .RAW :
561
628
return False
629
+ # @GRAD are named for gradient variables, checkpoint will not save it.
630
+ if "@GRAD" in var .name :
631
+ return False
632
+ # .trainer_ are named for distribute train variables, checkpoint will not save it.
633
+ if ".trainer_" in var .name :
634
+ return False
562
635
563
- if var .name .endswith ("@GRAD" ):
636
+ # .block is named for distribute train variables, checkpoint will not save it.
637
+ if ".block" in var .name :
564
638
return False
565
639
566
640
return var .persistable
567
641
568
642
569
- def _interval_secs_exceed (dirname , save_interval_secs ):
570
- dir_time = os .path .getmtime (dirname )
571
- if save_interval_secs > (time .time () - dir_time ):
572
- return False
573
- return True
643
+ def _get_dir_serial (dirname ):
644
+ _ , serial = dirname .split (CHECKPOINT_SEPARATOR )
645
+
646
+ try :
647
+ serial_num = int (serial )
648
+ except ValueError :
649
+ serial_num = - 1
650
+ return serial_num
651
+
652
+
653
+ def _get_serial_dir (dirname , serial ):
654
+ serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str (serial )
655
+ serial_dir = os .path .join (dirname , serial_folder )
656
+
657
+ if not os .path .isdir (serial_dir ):
658
+ os .makedirs (serial_dir )
659
+
660
+ return serial_dir
661
+
574
662
663
+ def _get_model_dir (dirname ):
664
+ model_dir = os .path .join (dirname , MODEL_DIR )
575
665
576
- def _lru_delete (dirname , max_num_checkpoints = 3 ):
666
+ if not os .path .isdir (model_dir ):
667
+ os .makedirs (model_dir )
668
+
669
+ return model_dir
670
+
671
+
672
+ def _get_trainer_dir (dirname , trainer_id ):
673
+ trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str (trainer_id )
674
+ trainer_dir = os .path .join (dirname , trainer_folder )
675
+
676
+ if not os .path .isdir (trainer_dir ):
677
+ os .makedirs (trainer_dir )
678
+
679
+ return trainer_dir
680
+
681
+
682
+ def _scroll_delete (dirname , max_num_checkpoints = 3 ):
577
683
dirs = os .listdir (dirname )
578
- serials = []
684
+ serial_map = {}
579
685
for serial in dirs :
580
- try :
581
- serials .append (int (serial ))
582
- except ValueError :
583
- continue
686
+ serial_num = _get_dir_serial (serial )
687
+ serial_map [serial_num ] = serial
584
688
585
- if len (serials ) <= max_num_checkpoints :
689
+ if len (serial_map . keys () ) <= max_num_checkpoints :
586
690
return
587
691
692
+ serials = serial_map .keys ()
588
693
serials .sort (reverse = True )
589
694
serials = serials [max_num_checkpoints :]
590
695
for serial in serials :
591
- cur_dir = os . path . join (dirname , str ( serial ) )
696
+ cur_dir = _get_serial_dir (dirname , serial )
592
697
shutil .rmtree (cur_dir )
593
698
594
699
@@ -604,33 +709,30 @@ def _write_success(dirname):
604
709
f .write (now )
605
710
606
711
607
- def _get_lastest_checkpoint_dir (checkpoint_dir ):
712
+ def get_latest_checkpoint_serial (checkpoint_dir ):
608
713
"""
609
714
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
610
715
611
716
:param checkpoint_dir
612
717
"""
613
- if not checkpoint_dir . strip () :
718
+ if not checkpoint_dir :
614
719
return - 1
615
720
616
721
def has_success (checkpoint_dir , cur_dir ):
617
722
"""
618
723
is _SUCCESS in this dir
619
724
"""
620
- _ , serial = cur_dir .split (CHECKPOINT_SEPARATOR )
621
-
622
- try :
623
- int (serial )
624
- except ValueError :
625
- return - 1
626
725
627
- if not os .path .isdir (os .path .join (checkpoint_dir , cur_dir )):
726
+ serial = _get_dir_serial (cur_dir )
727
+ if serial == - 1 or not os .path .isdir (
728
+ os .path .join (checkpoint_dir , cur_dir )):
628
729
return - 1
629
730
630
731
success_path = os .path .join (
631
- _get_serial_dir (serial , checkpoint_dir ), SUCCESS_MARK_FILENAME )
732
+ _get_serial_dir (checkpoint_dir , serial ), MODEL_DIR ,
733
+ SUCCESS_MARK_FILENAME )
632
734
if os .path .isfile (success_path ):
633
- return int ( serial )
735
+ return serial
634
736
635
737
if not os .path .isdir (checkpoint_dir ):
636
738
return - 1
0 commit comments