
Commit c1c7677

Update train_with_tensorboard.py
1 parent 8c04cbd commit c1c7677


train_with_tensorboard.py

Lines changed: 69 additions & 53 deletions
@@ -15,6 +15,7 @@
 from nets.yolo_training import YOLOLoss,Generator
 from nets.yolo4 import YoloBody
 from tensorboardX import SummaryWriter
+from tqdm import tqdm
 
 #---------------------------------------------------#
 #   Get classes and anchor boxes
@@ -33,76 +34,91 @@ def get_anchors(anchors_path):
     anchors = [float(x) for x in anchors.split(',')]
     return np.array(anchors).reshape([-1,3,2])[::-1,:,:]
 
+def get_lr(optimizer):
+    for param_group in optimizer.param_groups:
+        return param_group['lr']
+
 def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epoch,cuda,writer):
     total_loss = 0
     val_loss = 0
     start_time = time.time()
-    for iteration, batch in enumerate(gen):
-        if iteration >= epoch_size:
-            break
-        images, targets = batch[0], batch[1]
-        with torch.no_grad():
-            if cuda:
-                images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
-                targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
-            else:
-                images = Variable(torch.from_numpy(images).type(torch.FloatTensor))
-                targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
-        optimizer.zero_grad()
-        outputs = net(images)
-        losses = []
-        for i in range(3):
-            loss_item = yolo_losses[i](outputs[i], targets)
-            losses.append(loss_item[0])
-        loss = sum(losses)
-        loss.backward()
-        optimizer.step()
-        # Write the loss to TensorBoard at every step
-        writer.add_scalar('Train_loss', loss, (epoch*epoch_size + iteration))
-
-        total_loss += loss
-        waste_time = time.time() - start_time
-        print('\nEpoch:'+ str(epoch+1) + '/' + str(Epoch))
-        print('iter:' + str(iteration) + '/' + str(epoch_size) + ' || Total Loss: %.4f || %.4fs/step' % (total_loss/(iteration+1),waste_time))
-        start_time = time.time()
-        # Write the loss to TensorBoard; the commented line below logs once per epoch
-        # writer.add_scalar('Train_loss', total_loss/(iteration+1), epoch)
-
-    print('Start Validation')
-    for iteration, batch in enumerate(genval):
-        if iteration >= epoch_size_val:
-            break
-        images_val, targets_val = batch[0], batch[1]
-
-        with torch.no_grad():
-            if cuda:
-                images_val = Variable(torch.from_numpy(images_val).type(torch.FloatTensor)).cuda()
-                targets_val = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets_val]
-            else:
-                images_val = Variable(torch.from_numpy(images_val).type(torch.FloatTensor))
-                targets_val = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets_val]
+    with tqdm(total=epoch_size,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
+        for iteration, batch in enumerate(gen):
+            if iteration >= epoch_size:
+                break
+            images, targets = batch[0], batch[1]
+            with torch.no_grad():
+                if cuda:
+                    images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
+                    targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
+                else:
+                    images = Variable(torch.from_numpy(images).type(torch.FloatTensor))
+                    targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
             optimizer.zero_grad()
-            outputs = net(images_val)
+            outputs = net(images)
             losses = []
-
             for i in range(3):
-                loss_item = yolo_losses[i](outputs[i], targets_val)
+                loss_item = yolo_losses[i](outputs[i], targets)
                 losses.append(loss_item[0])
             loss = sum(losses)
-            val_loss += loss
-        # Write the loss to TensorBoard; the commented line below logs at every step
-        # writer.add_scalar('Val_loss',val_loss/(epoch_size_val+1), (epoch*epoch_size_val + iteration))
+            loss.backward()
+            optimizer.step()
+            # Write the loss to TensorBoard at every step
+            writer.add_scalar('Train_loss', loss, (epoch*epoch_size + iteration))
+
+            total_loss += loss
+            waste_time = time.time() - start_time
+
+            pbar.set_postfix(**{'total_loss': total_loss.item() / (iteration + 1),
+                                'lr'        : get_lr(optimizer),
+                                'step/s'    : waste_time})
+            pbar.update(1)
+
+
+            start_time = time.time()
+
+        # Write the loss to TensorBoard; the commented line below logs once per epoch
+        # writer.add_scalar('Train_loss', total_loss/(iteration+1), epoch)
+
+    print('Start Validation')
+    with tqdm(total=epoch_size_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
+        for iteration, batch in enumerate(genval):
+            if iteration >= epoch_size_val:
+                break
+            images_val, targets_val = batch[0], batch[1]
+
+            with torch.no_grad():
+                if cuda:
+                    images_val = Variable(torch.from_numpy(images_val).type(torch.FloatTensor)).cuda()
+                    targets_val = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets_val]
+                else:
+                    images_val = Variable(torch.from_numpy(images_val).type(torch.FloatTensor))
+                    targets_val = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets_val]
+                optimizer.zero_grad()
+                outputs = net(images_val)
+                losses = []
+
+                for i in range(3):
+                    loss_item = yolo_losses[i](outputs[i], targets_val)
+                    losses.append(loss_item[0])
+                loss = sum(losses)
+                val_loss += loss
+                # Write the loss to TensorBoard; the commented line below logs at every step
+                # writer.add_scalar('Val_loss',val_loss/(epoch_size_val+1), (epoch*epoch_size_val + iteration))
+
+            pbar.set_postfix(**{'total_loss': val_loss.item() / (iteration + 1)})
+            pbar.update(1)
+
     # Write the loss to TensorBoard, saved once per epoch
     writer.add_scalar('Val_loss',val_loss/(epoch_size_val+1), epoch)
     print('Finish Validation')
-    print('\nEpoch:'+ str(epoch+1) + '/' + str(Epoch))
+    print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
     print('Total Loss: %.4f || Val Loss: %.4f ' % (total_loss/(epoch_size+1),val_loss/(epoch_size_val+1)))
 
     print('Saving state, iter:', str(epoch+1))
     torch.save(model.state_dict(), 'logs/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.pth'%((epoch+1),total_loss/(epoch_size+1),val_loss/(epoch_size_val+1)))
 
 
-
 if __name__ == "__main__":
     #-------------------------------#
     #   Input shape size
@@ -248,4 +264,4 @@ def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epo
 
     for epoch in range(Freeze_Epoch,Unfreeze_Epoch):
         fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,gen_val,Unfreeze_Epoch,Cuda,writer)
-        lr_scheduler.step()
+        lr_scheduler.step()
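
The substance of this commit is the progress reporting inside fit_ont_epoch: the per-iteration print statements are replaced by a tqdm progress bar whose postfix shows the running loss, the current learning rate (via the new get_lr helper), and the time per step, while writer.add_scalar keeps logging 'Train_loss' per step and 'Val_loss' per epoch. Below is a minimal, self-contained sketch of that pattern; the dummy linear model, random data, loss, and 'logs' directory are placeholders for illustration only and are not part of the commit.

import time
import torch
from tqdm import tqdm
from tensorboardX import SummaryWriter   # same writer class the script imports

def get_lr(optimizer):
    # Return the learning rate of the first parameter group, as the commit's helper does.
    for param_group in optimizer.param_groups:
        return param_group['lr']

# Placeholder model, optimizer, and writer, only to make the sketch runnable.
model     = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
writer    = SummaryWriter(log_dir='logs')   # hypothetical log directory
epoch, Epoch, epoch_size = 0, 1, 20

total_loss = 0
start_time = time.time()
with tqdm(total=epoch_size, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:
    for iteration in range(epoch_size):
        x = torch.randn(8, 10)
        loss = model(x).pow(2).mean()          # stand-in for the summed YOLO losses

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log the step loss under the same tag the script uses.
        writer.add_scalar('Train_loss', loss, epoch * epoch_size + iteration)

        total_loss += loss
        waste_time = time.time() - start_time

        # The progress-bar postfix replaces the old per-iteration print statements.
        pbar.set_postfix(**{'total_loss': total_loss.item() / (iteration + 1),
                            'lr'        : get_lr(optimizer),
                            'step/s'    : waste_time})
        pbar.update(1)
        start_time = time.time()
writer.close()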
