Skip to content

Commit 94a995d

Browse files
committed
build log_str dynamicly
1 parent edc4ec4 commit 94a995d

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

ssd/engine/trainer.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,19 +86,21 @@ def do_train(cfg, model,
8686
end = time.time()
8787
if iteration % args.log_step == 0:
8888
eta_seconds = int((trained_time / iteration) * (max_iter - iteration))
89-
logger.info(
90-
"Iter: {:06d}, Lr: {:.5f}, Cost: {:.2f}s, Eta: {}, ".format(iteration, optimizer.param_groups[0]['lr'],
91-
time.time() - tic,
92-
str(datetime.timedelta(seconds=eta_seconds))) +
93-
"Loss: {:.3f}, ".format(losses_reduced.item()) +
94-
"Regression Loss {:.3f}, ".format(loss_dict_reduced['regression_loss'].item()) +
95-
"Classification Loss: {:.3f}".format(loss_dict_reduced['classification_loss'].item()))
96-
89+
log_str = [
90+
"Iter: {:06d}, Lr: {:.5f}, Cost: {:.2f}s, Eta: {}".format(iteration,
91+
optimizer.param_groups[0]['lr'],
92+
time.time() - tic, str(datetime.timedelta(seconds=eta_seconds))),
93+
"total_loss: {:.3f}".format(losses_reduced.item())
94+
]
95+
for loss_name, loss_item in loss_dict_reduced.items():
96+
log_str.append("{}: {:.3f}".format(loss_name, loss_item.item()))
97+
log_str = ', '.join(log_str)
98+
logger.info(log_str)
9799
if summary_writer:
98100
global_step = iteration
99-
summary_writer.add_scalar('losses/total_loss', losses_reduced.item(), global_step=global_step)
100-
summary_writer.add_scalar('losses/location_loss', loss_dict_reduced['regression_loss'].item(), global_step=global_step)
101-
summary_writer.add_scalar('losses/class_loss', loss_dict_reduced['classification_loss'].item(), global_step=global_step)
101+
summary_writer.add_scalar('losses/total_loss', losses_reduced, global_step=global_step)
102+
for loss_name, loss_item in loss_dict_reduced.items():
103+
summary_writer.add_scalar('losses/{}'.format(loss_name), loss_item, global_step=global_step)
102104
summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step=global_step)
103105

104106
tic = time.time()

0 commit comments

Comments
 (0)