Skip to content

Commit 6db240d

Browse files
committed
update trainer about epoch_id and step id
1 parent 3b5e3f9 commit 6db240d

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

python/paddle/fluid/trainer.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def __init__(self,
188188
if not self.checkpoint.is_pserver:
189189
epoch_id, step_id = io.load_trainer_args(
190190
self.checkpoint.checkpoint_dir, self.checkpoint.load_serial,
191-
self.trainer_id, ["epoch_id", "step_id"])
191+
self.trainer_id, self._get_checkpoint_load_args())
192192
self.checkpoint.epoch_id = int(epoch_id)
193193
self.checkpoint.step_id = int(step_id)
194194

@@ -432,22 +432,33 @@ def _clean_checkpoint(self):
432432
return
433433
io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir)
434434

435+
def _get_checkpoint_load_args(self):
436+
"""
437+
epoch_id and step_id are runtime arguments, they are not variables, will load them independently.
438+
"""
439+
return ["epoch_id", "step_id"]
440+
441+
def _get_checkpoint_save_args(self, epoch_id, step_id):
442+
"""
443+
epoch_id and step_id are runtime arguments, they are not variables, will save them independently.
444+
"""
445+
trainer_args = {}
446+
trainer_args["epoch_id"] = epoch_id
447+
trainer_args["step_id"] = step_id
448+
return trainer_args
449+
435450
def _save_checkpoint(self, epoch_id, step_id):
436451
if not self.checkpoint:
437452
return
438453

439454
if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0:
440-
trainer_args = {}
441-
trainer_args["epoch_id"] = epoch_id
442-
trainer_args["step_id"] = step_id
443-
444455
exe = executor.Executor(self.place)
445456
io.save_checkpoint(
446457
executor=exe,
447458
checkpoint_dir=self.checkpoint.checkpoint_dir,
448459
trainer_id=self.trainer_id,
449460
is_chief=self.chief,
450-
trainer_args=trainer_args,
461+
trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
451462
main_program=self.train_program,
452463
max_num_checkpoints=self.checkpoint.max_num_checkpoints)
453464

0 commit comments

Comments
 (0)