Skip to content

Commit 4af2db0

Browse files
authored
[CheckpointSaver] Add saving listeners support for increment checkpoint saver. (#915)
Signed-off-by: chenbangduo.cbd <[email protected]>
1 parent e2037de commit 4af2db0

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

tensorflow/python/training/basic_session_run_hooks.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -610,9 +610,7 @@ def after_run(self, run_context, run_values):
610610
global_step = run_context.session.run(self._global_step_tensor)
611611
if self._incremental_timer.should_trigger_for_step(global_step):
612612
self._incremental_timer.update_last_triggered_step(global_step)
613-
logging.info("Start Save incremental checkpoints for %d into %s.", global_step, self._incremental_save_path)
614-
self._get_incr_saver().incremental_save(run_context.session, self._incremental_save_path, global_step=global_step)
615-
logging.info("Finish Save incremental checkpoints for %d into %s.", global_step, self._incremental_save_path)
613+
self._incr_save(run_context.session, global_step)
616614

617615

618616
def end(self, session):
@@ -666,6 +664,18 @@ def _get_saver(self):
666664
self._saver = savers[0]
667665
return savers[0]
668666

667+
def _incr_save(self, session, step):
668+
logging.info("Saving incremental checkpoints for %d into %s.", step,
669+
self._incremental_save_path)
670+
for l in self._listeners:
671+
l.before_save(session, step)
672+
673+
self._get_incr_saver().incremental_save(session,
674+
self._incremental_save_path,
675+
global_step=step)
676+
for l in self._listeners:
677+
l.after_save(session, step)
678+
669679
def _get_incr_saver(self):
670680
if self._scaffold is not None:
671681
return self._scaffold._incr_saver

tensorflow/python/training/monitored_session.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,8 @@ def MonitoredTrainingSession(
491491
save_checkpoint_steps=USE_DEFAULT,
492492
summary_dir=None,
493493
save_incremental_checkpoint_secs=None,
494-
target_nodes_or_tensors=None):
494+
target_nodes_or_tensors=None,
495+
saving_listeners=None):
495496

496497
"""Creates a `MonitoredSession` for training.
497498
@@ -548,6 +549,9 @@ def MonitoredTrainingSession(
548549
summaries. If None, checkpoint_dir is used instead.
549550
target_nodes_or_tensors: list of tf.Tensor or tf.Operation indicates
550551
targets, which determine graph transformation of 'smart-stage'
552+
saving_listeners: List of `CheckpointSaverListener` subclass instances. Used
553+
for callbacks that run immediately before or after this hook saves the
554+
checkpoint.
551555
552556
Returns:
553557
A `MonitoredSession` object.
@@ -648,6 +652,7 @@ def MonitoredTrainingSession(
648652
save_steps=save_checkpoint_steps,
649653
save_secs=save_checkpoint_secs,
650654
scaffold=scaffold,
655+
listeners=saving_listeners,
651656
incremental_save_secs=save_incremental_checkpoint_secs))
652657

653658
if hooks:

0 commit comments

Comments
 (0)