Skip to content

Commit b8ee0d6

Browse files
committed
format save_step output information
1 parent 089db95 commit b8ee0d6

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

core/trainers/framework/runner.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def save_persistables():
420420
dirname = envs.get_global_env(name + "save_checkpoint_path", None)
421421
if dirname is None or dirname == "":
422422
return
423-
dirname = os.path.join(dirname, str(epoch_id))
423+
dirname = os.path.join(dirname, "epoch_" + str(epoch_id))
424424
logging.info("\tsave epoch_id:%d model into: \"%s\"" %
425425
(epoch_id, dirname))
426426
if is_fleet:
@@ -436,9 +436,11 @@ def save_checkpoint_step():
436436
dirname = envs.get_global_env(name + "save_step_path", None)
437437
if dirname is None or dirname == "":
438438
return
439-
dirname = os.path.join(dirname, str(batch_id))
440-
logging.info("\tsave batch_id:%d model into: \"%s\"" %
441-
(batch_id, dirname))
439+
dirname = os.path.join(dirname,
440+
"epoch_" + str(context["current_epoch"]) +
441+
"_batch_" + str(batch_id))
442+
logging.info("\tsave epoch_id:%d, batch_id:%d model into: \"%s\"" %
443+
(context["current_epoch"], batch_id, dirname))
442444
if is_fleet:
443445
if context["fleet"].worker_index() == 0:
444446
context["fleet"].save_persistables(context["exe"], dirname)

0 commit comments

Comments
 (0)