Skip to content

Commit 397a69d

Browse files
authored
Merge pull request #10532 from seiriosPlus/checkpoint
add checkpoint util class and implement
2 parents 20bdc3e + 2c47e06 commit 397a69d

File tree

2 files changed

+198
-9
lines changed

2 files changed

+198
-9
lines changed

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
190190
for (auto &var : sparse_vars) {
191191
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
192192
}
193+
193194
rpc_service_->SetCond(1);
194195
// FIXME(typhoonzero): use another condition to sync wait clients get.
195196
rpc_service_->WaitClientGet(fan_in);

python/paddle/fluid/io.py

Lines changed: 197 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,18 @@
1313
# limitations under the License.
1414

1515
import os
16+
import time
17+
import shutil
1618

1719
from paddle.fluid.evaluator import Evaluator
1820
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
1921
from . import core
2022

2123
__all__ = [
22-
'save_vars',
23-
'save_params',
24-
'save_persistables',
25-
'load_vars',
26-
'load_params',
27-
'load_persistables',
28-
'save_inference_model',
29-
'load_inference_model',
30-
'get_inference_program',
24+
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
25+
'load_persistables', 'save_inference_model', 'load_inference_model',
26+
'get_inference_program', 'save_checkpoint', 'load_checkpoint',
27+
'clean_checkpoint'
3128
]
3229

3330

@@ -195,6 +192,8 @@ def load_vars(executor,
195192
load_var_map = {}
196193
for each_var in vars:
197194
assert isinstance(each_var, Variable)
195+
if each_var.type == core.VarDesc.VarType.RAW:
196+
continue
198197
new_var = _clone_var_in_block_(load_block, each_var)
199198
if filename is None:
200199
load_block.append_op(
@@ -454,3 +453,192 @@ def get_parameter_value_by_name(name, executor, program=None):
454453
program = default_main_program()
455454
var = program.global_block().var(name)
456455
return get_parameter_value(var, executor)
456+
457+
458+
SUCCESS_MARK_FILENAME = "_SUCCESS"
459+
CHECKPOINT_PREFIX = "checkpoint"
460+
CHECKPOINT_SEPARATOR = "_"
461+
462+
463+
def save_checkpoint(executor,
                    checkpoint_dir=None,
                    max_num_checkpoints=3,
                    save_interval_secs=600,
                    main_program=None):
    """
    Save the checkpoint variables of main_program into a new serial
    sub-directory of checkpoint_dir.

    The new sub-directory is named by `_get_serial_dir` using the latest
    successful serial plus one (the first checkpoint gets serial 0).
    Nothing is saved when a successful checkpoint already exists and
    `_interval_secs_exceed` reports that save_interval_secs has not yet
    elapsed for it. After saving, a success marker is written and
    `_lru_delete` is asked to keep at most max_num_checkpoints checkpoints.

    :param executor: executor passed through to save_vars
    :param checkpoint_dir: root directory of the checkpoints; defaults to
        the current working directory and is created when missing
    :param max_num_checkpoints: upper bound handed to `_lru_delete`
    :param save_interval_secs: minimum interval, in seconds, between two
        saved checkpoints
    :param main_program: program passed through to save_vars
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.getcwd()

    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    latest_serial = _get_lastest_checkpoint_dir(checkpoint_dir)
    if latest_serial >= 0:
        latest_dir = _get_serial_dir(latest_serial, checkpoint_dir)
        # Too soon after the previous successful checkpoint: skip this one.
        if not _interval_secs_exceed(latest_dir, save_interval_secs):
            return

    target_dir = _get_serial_dir(latest_serial + 1, checkpoint_dir)

    save_vars(
        executor,
        dirname=target_dir,
        main_program=main_program,
        vars=None,
        predicate=_is_checkpoint_var,
        filename=None)
    # The marker is written only after save_vars returned, so a directory
    # without it is treated as an incomplete checkpoint.
    _write_success(target_dir)
    _lru_delete(checkpoint_dir, max_num_checkpoints)
503+
504+
505+
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
    """
    Load the most recent successful checkpoint found under checkpoint_dir.

    The serial is looked up with `_get_lastest_checkpoint_dir`; when no
    successful checkpoint is found (negative serial) the function returns
    without loading anything.

    :param executor: executor passed through to load_vars
    :param checkpoint_dir: root directory of the checkpoints; defaults to
        the current working directory
    :param main_program: program passed through to load_vars
    """
    root_dir = os.getcwd() if checkpoint_dir is None else checkpoint_dir

    latest_serial = _get_lastest_checkpoint_dir(root_dir)

    if latest_serial >= 0:
        load_vars(
            executor,
            dirname=_get_serial_dir(latest_serial, root_dir),
            main_program=main_program,
            predicate=_is_checkpoint_var,
            filename=None)
531+
532+
533+
def clean_checkpoint(checkpoint_dir, delete_dir=False):
    """
    Clean the checkpoint directory; intended to be called by the trainer
    when training exits normally.

    Checkpoint removal is delegated to `_lru_delete` with
    max_num_checkpoints=0. When delete_dir is set, checkpoint_dir itself
    is removed as well, but only if it is empty at that point; a non-empty
    directory is left untouched.

    :param checkpoint_dir: root directory of the checkpoints; None means
        the current working directory
    :param delete_dir: also remove checkpoint_dir when it ends up empty
    """
    if checkpoint_dir is None:
        checkpoint_dir = os.getcwd()
    _lru_delete(checkpoint_dir, max_num_checkpoints=0)

    if not delete_dir:
        return

    leftovers = os.listdir(checkpoint_dir)
    if not leftovers:
        os.rmdir(checkpoint_dir)
544+
545+
546+
def _get_serial_dir(serial, checkpoint_dir):
    """
    Build the path of the checkpoint sub-directory for a serial number.

    :param serial: serial number of the checkpoint (converted with str())
    :param checkpoint_dir: root directory of the checkpoints
    :return: checkpoint_dir joined with "<prefix><separator><serial>"
    """
    folder_name = "".join(
        [CHECKPOINT_PREFIX, CHECKPOINT_SEPARATOR, str(serial)])
    return os.path.join(checkpoint_dir, folder_name)
549+
550+
551+
def _is_checkpoint_var(var):
    """
    Predicate selecting the variables that belong in a checkpoint.

    Variables whose type is FEED_MINIBATCH, FETCH_LIST or RAW, and
    variables whose name ends with "@GRAD", are rejected; any other
    variable is accepted according to its `persistable` attribute.

    :param var: the variable to test
    :return: False for the rejected kinds, otherwise var.persistable
    """
    var_type = var.desc.type()
    skipped_types = (core.VarDesc.VarType.FEED_MINIBATCH,
                     core.VarDesc.VarType.FETCH_LIST,
                     core.VarDesc.VarType.RAW)
    if any(var_type == skipped for skipped in skipped_types):
        return False

    # Gradient variables are not part of a checkpoint.
    if var.name.endswith("@GRAD"):
        return False

    return var.persistable
567+
568+
569+
def _interval_secs_exceed(dirname, save_interval_secs):
    """
    Tell whether at least save_interval_secs seconds have passed since
    dirname was last modified.

    :param dirname: path whose modification time is inspected
    :param save_interval_secs: required interval in seconds
    :return: False while the interval is still larger than the elapsed
        time, True otherwise
    """
    modified_at = os.path.getmtime(dirname)
    elapsed_secs = time.time() - modified_at
    return not (save_interval_secs > elapsed_secs)
574+
575+
576+
def _lru_delete(dirname, max_num_checkpoints=3):
    """
    Remove old checkpoint directories so that at most max_num_checkpoints
    remain under dirname.

    Only sub-directories named "<CHECKPOINT_PREFIX><CHECKPOINT_SEPARATOR><n>"
    with an integer n are considered, which is the naming produced by
    `_get_serial_dir`; every other entry is left alone. Despite the name,
    selection is by serial number: the checkpoints with the highest
    serials are kept and the rest are deleted recursively.

    :param dirname: root directory of the checkpoints
    :param max_num_checkpoints: number of checkpoints to keep; 0 removes
        all of them
    """
    name_prefix = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR

    # (serial, entry name) pairs; the original entry name is kept so the
    # exact directory that was listed is the one that gets removed.
    checkpoints = []
    for entry in os.listdir(dirname):
        if not entry.startswith(name_prefix):
            continue
        try:
            serial = int(entry[len(name_prefix):])
        except ValueError:
            continue
        # Never touch plain files that merely look like a checkpoint name.
        if not os.path.isdir(os.path.join(dirname, entry)):
            continue
        checkpoints.append((serial, entry))

    if len(checkpoints) <= max_num_checkpoints:
        return

    # Newest (highest serial) first; everything past the quota is stale.
    checkpoints.sort(reverse=True)
    for _, entry in checkpoints[max_num_checkpoints:]:
        shutil.rmtree(os.path.join(dirname, entry))
593+
594+
595+
def _write_success(dirname):
    """
    Mark the checkpoint stored in dirname as complete.

    A file named by SUCCESS_MARK_FILENAME is opened in append mode inside
    dirname and the current time (time.ctime()) is written to it.

    :param dirname: checkpoint directory that receives the marker file
    """
    marker_path = os.path.join(dirname, SUCCESS_MARK_FILENAME)
    with open(marker_path, 'a') as marker:
        marker.write(time.ctime())
605+
606+
607+
def _get_lastest_checkpoint_dir(checkpoint_dir):
    """
    Find the highest serial among the successful checkpoints stored in
    checkpoint_dir.

    An entry counts as a successful checkpoint when it is a directory
    named "<CHECKPOINT_PREFIX><CHECKPOINT_SEPARATOR><n>", with n a
    non-negative integer in canonical form, that contains the
    SUCCESS_MARK_FILENAME file. Any other entry is ignored.

    :param checkpoint_dir: root directory of the checkpoints
    :return: the highest successful serial, or -1 when checkpoint_dir is
        empty/blank, is not a directory, or holds no successful checkpoint
    """
    if not checkpoint_dir or not checkpoint_dir.strip():
        return -1

    if not os.path.isdir(checkpoint_dir):
        return -1

    name_prefix = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR

    def success_serial(entry):
        """
        Return the serial of entry if it is a successful checkpoint,
        otherwise -1.
        """
        if not entry.startswith(name_prefix):
            return -1

        suffix = entry[len(name_prefix):]
        try:
            serial = int(suffix)
        except ValueError:
            return -1

        # Reject "-1", "007", "+3" and the like: callers rebuild the
        # directory name from the integer, so only the canonical spelling
        # maps back to this entry.
        if serial < 0 or str(serial) != suffix:
            return -1

        entry_path = os.path.join(checkpoint_dir, entry)
        if not os.path.isdir(entry_path):
            return -1

        if not os.path.isfile(
                os.path.join(entry_path, SUCCESS_MARK_FILENAME)):
            # Always return an int so the comparison below is valid on
            # Python 3, where None cannot be ordered against integers.
            return -1

        return serial

    latest_serial = -1
    for entry in os.listdir(checkpoint_dir):
        candidate = success_serial(entry)
        if candidate > latest_serial:
            latest_serial = candidate
    return latest_serial

0 commit comments

Comments
 (0)