Skip to content

Commit e94c91d

Browse files
authored
refix disp_dict when distributed (#700)
1 parent 65554a5 commit e94c91d

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

tools/train_utils/train_utils.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,16 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
6161
avg_forward_time = commu_utils.average_reduce_value(cur_forward_time)
6262
avg_batch_time = commu_utils.average_reduce_value(cur_batch_time)
6363

64+
# log to console and tensorboard
6465
if rank == 0:
6566
data_time.update(avg_data_time)
6667
forward_time.update(avg_forward_time)
6768
batch_time.update(avg_batch_time)
69+
disp_dict.update({
70+
'loss': loss.item(), 'lr': cur_lr, 'd_time': f'{data_time.val:.2f}({data_time.avg:.2f})',
71+
'f_time': f'{forward_time.val:.2f}({forward_time.avg:.2f})', 'b_time': f'{batch_time.val:.2f}({batch_time.avg:.2f})'
72+
})
6873

69-
70-
disp_dict.update({
71-
'loss': loss.item(), 'lr': cur_lr, 'd_time': f'{data_time.val:.2f}({data_time.avg:.2f})',
72-
'f_time': f'{forward_time.val:.2f}({forward_time.avg:.2f})', 'b_time': f'{batch_time.val:.2f}({batch_time.avg:.2f})'
73-
})
74-
75-
# log to console and tensorboard
76-
if rank == 0:
7774
pbar.update()
7875
pbar.set_postfix(dict(total_it=accumulated_iter))
7976
tbar.set_postfix(disp_dict)

0 commit comments

Comments
 (0)