Skip to content

Commit 9735f25

Browse files
committed
optimized
1 parent bfdcf18 commit 9735f25

File tree

2 files changed

+20
-32
lines changed

2 files changed

+20
-32
lines changed

python/paddle/fluid/io.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def save_checkpoint(executor,
492492
if not os.path.isdir(checkpoint_dir):
493493
os.makedirs(checkpoint_dir)
494494

495-
serial = _get_latest_checkpoint_dir(checkpoint_dir) + 1
495+
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
496496
cur_dir = _get_serial_dir(checkpoint_dir, serial)
497497

498498
save_trainer_args(cur_dir, trainer_id, trainer_args)
@@ -503,18 +503,6 @@ def save_checkpoint(executor,
503503
_lru_delete(checkpoint_dir, max_num_checkpoints)
504504

505505

506-
def get_latest_checkpoint_serial(checkpoint_dir):
507-
"""
508-
If the directory have checkpoint files, it will return latest checkpoint directory serial number
509-
510-
:param checkpoint_dir
511-
"""
512-
serial = _get_latest_checkpoint_dir(checkpoint_dir)
513-
if serial < 0:
514-
return None
515-
return serial
516-
517-
518506
def load_checkpoint(executor, checkpoint_dir, serial, main_program):
519507
"""
520508
Load checkpoint from a directory by executor,
@@ -527,17 +515,16 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
527515
"""
528516

529517
if checkpoint_dir is None:
530-
raise ValueError(
531-
"The values of 'checkpoint_dir' or 'serial' should not be None")
518+
raise ValueError("The values of 'checkpoint_dir' should not be None")
532519

533520
if serial is None or serial < 0:
534521
raise ValueError("The values of 'serial' should not be None or <0 ")
535522

536523
if main_program is None:
537-
raise ValueError("The values of 'main_program'should not be None")
524+
raise ValueError('main_program should not be None.')
538525

539526
cur_dir = _get_serial_dir(checkpoint_dir, serial)
540-
load_persist_vars_without_grad(executor, cur_dir, main_program)
527+
load_persist_vars_without_grad(executor, cur_dir, main_program, True)
541528

542529

543530
def clean_checkpoint(checkpoint_dir, delete_dir=False):
@@ -557,18 +544,21 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
557544
os.rmdir(checkpoint_dir)
558545

559546

560-
def load_persist_vars_without_grad(executor, dirname, program, nest=True):
547+
def load_persist_vars_without_grad(executor,
548+
dirname,
549+
program,
550+
has_model_dir=False):
561551
"""
562552
load_persist_vars_without_grad will load variables from a directory by an executor,
563553
the variable named end with "@GRAD" will not be loaded.
564554
565-
:param executor
566-
:param dirname
567-
:param program
568-
:param nest
555+
:param executor executor for load the value
556+
:param dirname the checkpoint directory
557+
:param program will load all variables in program
558+
:param has_model_dir if has_model_dir is True, will load variables from sub directory named __model__
569559
"""
570560

571-
if nest:
561+
if has_model_dir:
572562
dirname = _get_model_dir(dirname)
573563

574564
load_vars(
@@ -584,9 +574,9 @@ def save_persist_vars_without_grad(executor, dirname, program):
584574
save_persist_vars_without_grad will save variables to a directory by an executor,
585575
the variable named end with "@GRAD" will not be saved.
586576
587-
:param executor
588-
:param dirname
589-
:param program
577+
:param executor executor for load the value
578+
:param dirname the checkpoint directory
579+
:param program will load all variables in program
590580
"""
591581
cur_dir = _get_model_dir(dirname)
592582
save_vars(
@@ -722,7 +712,7 @@ def _write_success(dirname):
722712
f.write(now)
723713

724714

725-
def _get_latest_checkpoint_dir(checkpoint_dir):
715+
def get_latest_checkpoint_serial(checkpoint_dir):
726716
"""
727717
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
728718

python/paddle/fluid/trainer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,9 @@ def __init__(self,
146146
"The checkpoint_config shoule be an instance of CheckpointConfig"
147147
)
148148
else:
149-
self.checkpoint.load_serial = io.get_latest_checkpoint_serial(
149+
serial = io.get_latest_checkpoint_serial(
150150
self.checkpoint.checkpoint_dir)
151+
self.checkpoint.load_serial = serial if serial >= 0 else None
151152

152153
self.scope = core.Scope()
153154

@@ -194,10 +195,7 @@ def __init__(self,
194195
if param_path and os.path.isdir(param_path):
195196
# load params from param_path into scope
196197
io.load_persist_vars_without_grad(
197-
exe,
198-
dirname=param_path,
199-
program=self.startup_program,
200-
nest=False)
198+
exe, dirname=param_path, program=self.startup_program)
201199

202200
def _transpile_nccl2_dist(self):
203201
# PADDLE_TRAINER_IPS

0 commit comments

Comments
 (0)