@@ -492,7 +492,7 @@ def save_checkpoint(executor,
492
492
if not os .path .isdir (checkpoint_dir ):
493
493
os .makedirs (checkpoint_dir )
494
494
495
- serial = _get_latest_checkpoint_dir (checkpoint_dir ) + 1
495
+ serial = get_latest_checkpoint_serial (checkpoint_dir ) + 1
496
496
cur_dir = _get_serial_dir (checkpoint_dir , serial )
497
497
498
498
save_trainer_args (cur_dir , trainer_id , trainer_args )
@@ -503,18 +503,6 @@ def save_checkpoint(executor,
503
503
_lru_delete (checkpoint_dir , max_num_checkpoints )
504
504
505
505
506
- def get_latest_checkpoint_serial (checkpoint_dir ):
507
- """
508
- If the directory have checkpoint files, it will return latest checkpoint directory serial number
509
-
510
- :param checkpoint_dir
511
- """
512
- serial = _get_latest_checkpoint_dir (checkpoint_dir )
513
- if serial < 0 :
514
- return None
515
- return serial
516
-
517
-
518
506
def load_checkpoint (executor , checkpoint_dir , serial , main_program ):
519
507
"""
520
508
Load checkpoint from a directory by executor,
@@ -527,17 +515,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
527
515
"""
528
516
529
517
if checkpoint_dir is None :
530
- raise ValueError (
531
- "The values of 'checkpoint_dir' or 'serial' should not be None" )
518
+ raise ValueError ("The values of 'checkpoint_dir' should not be None" )
532
519
533
520
if serial is None or serial < 0 :
534
521
raise ValueError ("The values of 'serial' should not be None or <0 " )
535
522
536
523
if main_program is None :
537
- raise ValueError ("The values of 'main_program' should not be None" )
524
+ raise ValueError ('main_program should not be None.' )
538
525
539
526
cur_dir = _get_serial_dir (checkpoint_dir , serial )
540
- load_persist_vars_without_grad (executor , cur_dir , main_program )
527
+ load_persist_vars_without_grad (executor , cur_dir , main_program , True )
541
528
542
529
543
530
def clean_checkpoint (checkpoint_dir , delete_dir = False ):
@@ -557,18 +544,21 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
557
544
os .rmdir (checkpoint_dir )
558
545
559
546
560
- def load_persist_vars_without_grad (executor , dirname , program , nest = True ):
547
+ def load_persist_vars_without_grad (executor ,
548
+ dirname ,
549
+ program ,
550
+ has_model_dir = False ):
561
551
"""
562
552
load_persist_vars_without_grad will load variables from a directory by an executor,
563
553
the variable named end with "@GRAD" will not be loaded.
564
554
565
- :param executor
566
- :param dirname
567
- :param program
568
- :param nest
555
+ :param executor executor for load the value
556
+ :param dirname the checkpoint directory
557
+ :param program will load all variables in program
558
+ :param has_model_dir if has_model_dir is True, will load variables from sub directory named __model__
569
559
"""
570
560
571
- if nest :
561
+ if has_model_dir :
572
562
dirname = _get_model_dir (dirname )
573
563
574
564
load_vars (
@@ -584,9 +574,9 @@ def save_persist_vars_without_grad(executor, dirname, program):
584
574
save_persist_vars_without_grad will save variables to a directory by an executor,
585
575
the variable named end with "@GRAD" will not be saved.
586
576
587
- :param executor
588
- :param dirname
589
- :param program
577
+ :param executor executor for load the value
578
+ :param dirname the checkpoint directory
579
+ :param program will load all variables in program
590
580
"""
591
581
cur_dir = _get_model_dir (dirname )
592
582
save_vars (
@@ -722,7 +712,7 @@ def _write_success(dirname):
722
712
f .write (now )
723
713
724
714
725
- def _get_latest_checkpoint_dir (checkpoint_dir ):
715
+ def get_latest_checkpoint_serial (checkpoint_dir ):
726
716
"""
727
717
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
728
718
0 commit comments