Skip to content

Commit d896134

Browse files
authored
Merge pull request #10878 from seiriosPlus/new_api_about_cpkt
New api about checkpoint and models
2 parents 7bcc980 + bf2c53a commit d896134

File tree

4 files changed

+370
-83
lines changed

4 files changed

+370
-83
lines changed

python/paddle/fluid/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from trainer import EndEpochEvent
2727
from trainer import BeginStepEvent
2828
from trainer import EndStepEvent
29+
from trainer import CheckpointConfig
2930

3031
import inferencer
3132
from inferencer import Inferencer

python/paddle/fluid/io.py

Lines changed: 174 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
2525
'load_persistables', 'save_inference_model', 'load_inference_model',
2626
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
27-
'clean_checkpoint'
27+
'clean_checkpoint', 'load_persist_vars_without_grad',
28+
'save_persist_vars_without_grad', 'get_latest_checkpoint_serial'
2829
]
2930

3031

@@ -457,95 +458,161 @@ def get_parameter_value_by_name(name, executor, program=None):
457458

458459
SUCCESS_MARK_FILENAME = "_SUCCESS"
459460
CHECKPOINT_PREFIX = "checkpoint"
461+
MODEL_DIR = "__model__"
462+
TRAINER_PREFIX = "trainer"
460463
CHECKPOINT_SEPARATOR = "_"
461464

462465

463466
def save_checkpoint(executor,
464-
checkpoint_dir=None,
465-
max_num_checkpoints=3,
466-
save_interval_secs=600,
467-
main_program=None):
467+
checkpoint_dir,
468+
trainer_id,
469+
trainer_args=None,
470+
main_program=None,
471+
max_num_checkpoints=3):
468472
"""
469473
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
470474
each directory is named by its serial number, from 0 to (n - 1); save_checkpoint uses an LRU strategy
471475
to limit the number of checkpoint directories; at most max_num_checkpoints checkpoint directories are kept.
472476
The interval between two saved checkpoints must be greater than save_interval_secs.
473477
474-
:param executor
475-
:param checkpoint_dir
476-
:param max_num_checkpoints
477-
:param save_interval_secs
478-
:param main_program
478+
:param executor executor used to save the values
479+
:param checkpoint_dir the checkpoint directory
480+
:param trainer_id current trainer id; if the id is equal to 0, the trainer is the chief
481+
:param main_program will save all variables in program
482+
:param max_num_checkpoints will keep numbers of checkpoint serials not bigger than max_num_checkpoints
479483
"""
480484
if checkpoint_dir is None:
481-
checkpoint_dir = os.getcwd()
485+
raise ValueError("'checkpoint_dir' should not be None")
486+
487+
if trainer_args:
488+
assert isinstance(trainer_args, dict)
482489

483490
if not os.path.isdir(checkpoint_dir):
484491
os.makedirs(checkpoint_dir)
485492

486-
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
487-
if serial >= 0 and not _interval_secs_exceed(
488-
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
489-
return
493+
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
494+
cur_dir = _get_serial_dir(checkpoint_dir, serial)
490495

491-
serial += 1
492-
cur_dir = _get_serial_dir(serial, checkpoint_dir)
496+
save_trainer_args(cur_dir, trainer_id, trainer_args)
493497

494-
save_vars(
495-
executor,
496-
dirname=cur_dir,
497-
main_program=main_program,
498-
vars=None,
499-
predicate=_is_checkpoint_var,
500-
filename=None)
501-
_write_success(cur_dir)
502-
_lru_delete(checkpoint_dir, max_num_checkpoints)
498+
if trainer_id == 0:
499+
save_persist_vars_without_grad(executor, cur_dir, main_program)
500+
501+
_scroll_delete(checkpoint_dir, max_num_checkpoints)
503502

504503

505-
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
504+
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
506505
"""
507506
Load checkpoint from a directory by executor,
508507
it will find the most recently saved checkpoint file and load it automatically.
509508
510-
:param executor
511-
:param checkpoint_dir
512-
:param main_program
509+
:param executor executor used to load the values
510+
:param checkpoint_dir the checkpoint directory
511+
:param serial the serial folder in the checkpoint directory that will be loaded
512+
:param main_program will load all variables in program
513513
"""
514514

515515
if checkpoint_dir is None:
516-
checkpoint_dir = os.getcwd()
516+
raise ValueError("'checkpoint_dir' should not be None")
517517

518-
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
518+
if serial is None or serial < 0:
519+
raise ValueError("'serial' should not be None or <0 ")
519520

520-
if serial < 0:
521-
return
521+
if main_program is None:
522+
raise ValueError('main_program should not be None.')
522523

523-
cur_dir = _get_serial_dir(serial, checkpoint_dir)
524-
525-
load_vars(
526-
executor,
527-
dirname=cur_dir,
528-
main_program=main_program,
529-
predicate=_is_checkpoint_var,
530-
filename=None)
524+
cur_dir = _get_serial_dir(checkpoint_dir, serial)
525+
load_persist_vars_without_grad(executor, cur_dir, main_program, True)
531526

532527

533528
def clean_checkpoint(checkpoint_dir, delete_dir=False):
534529
"""
535530
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
536531
delete_dir only works when the directory is empty, otherwise, OSError is raised.
532+
533+
:param checkpoint_dir
534+
:param delete_dir
537535
"""
536+
538537
if checkpoint_dir is None:
539-
checkpoint_dir = os.getcwd()
540-
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
538+
raise ValueError("'checkpoint_dir' should not be None")
539+
_scroll_delete(checkpoint_dir, max_num_checkpoints=0)
541540

542541
if delete_dir and not os.listdir(checkpoint_dir):
543542
os.rmdir(checkpoint_dir)
544543

545544

546-
def _get_serial_dir(serial, checkpoint_dir):
547-
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
548-
return os.path.join(checkpoint_dir, serial_folder)
545+
def load_persist_vars_without_grad(executor,
546+
dirname,
547+
program,
548+
has_model_dir=False):
549+
"""
550+
load_persist_vars_without_grad will load variables from a directory by an executor,
551+
variables whose names end with "@GRAD" will not be loaded.
552+
553+
:param executor executor used to load the values
554+
:param dirname the checkpoint directory
555+
:param program will load all variables in program
556+
:param has_model_dir if has_model_dir is True, will load variables from sub directory named __model__
557+
"""
558+
559+
if has_model_dir:
560+
dirname = _get_model_dir(dirname)
561+
562+
load_vars(
563+
executor,
564+
dirname=dirname,
565+
main_program=program,
566+
predicate=_is_checkpoint_var,
567+
filename=None)
568+
569+
570+
def save_persist_vars_without_grad(executor, dirname, program):
571+
"""
572+
save_persist_vars_without_grad will save variables to a directory by an executor,
573+
variables whose names end with "@GRAD" will not be saved.
574+
575+
:param executor executor used to save the values
576+
:param dirname the checkpoint directory
577+
:param program will save all variables in the program
578+
"""
579+
cur_dir = _get_model_dir(dirname)
580+
save_vars(
581+
executor,
582+
dirname=cur_dir,
583+
main_program=program,
584+
vars=None,
585+
predicate=_is_checkpoint_var,
586+
filename=None)
587+
_write_success(cur_dir)
588+
589+
590+
def save_trainer_args(dirname, trainer_id, trainer_args):
591+
assert isinstance(trainer_args, dict)
592+
593+
cur_dir = _get_trainer_dir(dirname, trainer_id)
594+
595+
for name, value in trainer_args.iteritems():
596+
args_file = os.path.join(cur_dir, name)
597+
with open(args_file, 'w') as f:
598+
f.write(str(value))
599+
_write_success(cur_dir)
600+
601+
602+
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
603+
assert isinstance(trainer_args, list)
604+
605+
cur_dir = _get_serial_dir(checkpoint_dir, serial)
606+
cur_dir = _get_trainer_dir(cur_dir, trainer_id)
607+
608+
ret_values = []
609+
610+
for arg in trainer_args:
611+
cur_file = os.path.join(cur_dir, arg)
612+
with open(cur_file, 'r') as f:
613+
contents = f.read()
614+
ret_values.append(contents.strip())
615+
return ret_values
549616

550617

551618
def _is_checkpoint_var(var):
@@ -559,36 +626,74 @@ def _is_checkpoint_var(var):
559626
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
560627
var.desc.type() == core.VarDesc.VarType.RAW:
561628
return False
629+
# Variables whose names contain "@GRAD" are gradient variables; checkpoint will not save them.
630+
if "@GRAD" in var.name:
631+
return False
632+
# Variables whose names contain ".trainer_" are used in distributed training; checkpoint will not save them.
633+
if ".trainer_" in var.name:
634+
return False
562635

563-
if var.name.endswith("@GRAD"):
636+
# Variables whose names contain ".block" are used in distributed training; checkpoint will not save them.
637+
if ".block" in var.name:
564638
return False
565639

566640
return var.persistable
567641

568642

569-
def _interval_secs_exceed(dirname, save_interval_secs):
570-
dir_time = os.path.getmtime(dirname)
571-
if save_interval_secs > (time.time() - dir_time):
572-
return False
573-
return True
643+
def _get_dir_serial(dirname):
644+
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
645+
646+
try:
647+
serial_num = int(serial)
648+
except ValueError:
649+
serial_num = -1
650+
return serial_num
651+
652+
653+
def _get_serial_dir(dirname, serial):
654+
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
655+
serial_dir = os.path.join(dirname, serial_folder)
656+
657+
if not os.path.isdir(serial_dir):
658+
os.makedirs(serial_dir)
659+
660+
return serial_dir
661+
574662

663+
def _get_model_dir(dirname):
664+
model_dir = os.path.join(dirname, MODEL_DIR)
575665

576-
def _lru_delete(dirname, max_num_checkpoints=3):
666+
if not os.path.isdir(model_dir):
667+
os.makedirs(model_dir)
668+
669+
return model_dir
670+
671+
672+
def _get_trainer_dir(dirname, trainer_id):
673+
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
674+
trainer_dir = os.path.join(dirname, trainer_folder)
675+
676+
if not os.path.isdir(trainer_dir):
677+
os.makedirs(trainer_dir)
678+
679+
return trainer_dir
680+
681+
682+
def _scroll_delete(dirname, max_num_checkpoints=3):
577683
dirs = os.listdir(dirname)
578-
serials = []
684+
serial_map = {}
579685
for serial in dirs:
580-
try:
581-
serials.append(int(serial))
582-
except ValueError:
583-
continue
686+
serial_num = _get_dir_serial(serial)
687+
serial_map[serial_num] = serial
584688

585-
if len(serials) <= max_num_checkpoints:
689+
if len(serial_map.keys()) <= max_num_checkpoints:
586690
return
587691

692+
serials = serial_map.keys()
588693
serials.sort(reverse=True)
589694
serials = serials[max_num_checkpoints:]
590695
for serial in serials:
591-
cur_dir = os.path.join(dirname, str(serial))
696+
cur_dir = _get_serial_dir(dirname, serial)
592697
shutil.rmtree(cur_dir)
593698

594699

@@ -604,33 +709,30 @@ def _write_success(dirname):
604709
f.write(now)
605710

606711

607-
def _get_lastest_checkpoint_dir(checkpoint_dir):
712+
def get_latest_checkpoint_serial(checkpoint_dir):
608713
"""
609714
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
610715
611716
:param checkpoint_dir
612717
"""
613-
if not checkpoint_dir.strip():
718+
if not checkpoint_dir:
614719
return -1
615720

616721
def has_success(checkpoint_dir, cur_dir):
617722
"""
618723
is _SUCCESS in this dir
619724
"""
620-
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
621-
622-
try:
623-
int(serial)
624-
except ValueError:
625-
return -1
626725

627-
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
726+
serial = _get_dir_serial(cur_dir)
727+
if serial == -1 or not os.path.isdir(
728+
os.path.join(checkpoint_dir, cur_dir)):
628729
return -1
629730

630731
success_path = os.path.join(
631-
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
732+
_get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
733+
SUCCESS_MARK_FILENAME)
632734
if os.path.isfile(success_path):
633-
return int(serial)
735+
return serial
634736

635737
if not os.path.isdir(checkpoint_dir):
636738
return -1

0 commit comments

Comments
 (0)