Skip to content

Commit 8c04cbd

Browse files
authored
Update train.py
1 parent ec67409 commit 8c04cbd

File tree

1 file changed

+61
-50
lines changed

1 file changed

+61
-50
lines changed

train.py

Lines changed: 61 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from utils.dataloader import yolo_dataset_collate, YoloDataset
1515
from nets.yolo_training import YOLOLoss,Generator
1616
from nets.yolo4 import YoloBody
17-
17+
from tqdm import tqdm
1818

1919
#---------------------------------------------------#
2020
# 获得类和先验框
@@ -33,67 +33,78 @@ def get_anchors(anchors_path):
3333
anchors = [float(x) for x in anchors.split(',')]
3434
return np.array(anchors).reshape([-1,3,2])[::-1,:,:]
3535

36-
def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epoch,cuda):
36+
def get_lr(optimizer):
37+
for param_group in optimizer.param_groups:
38+
return param_group['lr']
39+
40+
def fit_one_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epoch,cuda):
3741
total_loss = 0
3842
val_loss = 0
3943
start_time = time.time()
40-
for iteration, batch in enumerate(gen):
41-
if iteration >= epoch_size:
42-
break
43-
images, targets = batch[0], batch[1]
44-
with torch.no_grad():
45-
if cuda:
46-
images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
47-
targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
48-
else:
49-
images = Variable(torch.from_numpy(images).type(torch.FloatTensor))
50-
targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
51-
optimizer.zero_grad()
52-
outputs = net(images)
53-
losses = []
54-
for i in range(3):
55-
loss_item = yolo_losses[i](outputs[i], targets)
56-
losses.append(loss_item[0])
57-
loss = sum(losses)
58-
loss.backward()
59-
optimizer.step()
60-
61-
total_loss += loss
62-
waste_time = time.time() - start_time
63-
print('\nEpoch:'+ str(epoch+1) + '/' + str(Epoch))
64-
print('iter:' + str(iteration) + '/' + str(epoch_size) + ' || Total Loss: %.4f || %.4fs/step' % (total_loss/(iteration+1),waste_time))
65-
start_time = time.time()
66-
67-
print('Start Validation')
68-
for iteration, batch in enumerate(genval):
69-
if iteration >= epoch_size_val:
70-
break
71-
images_val, targets_val = batch[0], batch[1]
72-
73-
with torch.no_grad():
74-
if cuda:
75-
images_val = Variable(torch.from_numpy(images_val).type(torch.FloatTensor)).cuda()
76-
targets_val = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets_val]
77-
else:
78-
images_val = Variable(torch.from_numpy(images_val).type(torch.FloatTensor))
79-
targets_val = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets_val]
44+
with tqdm(total=epoch_size,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
45+
for iteration, batch in enumerate(gen):
46+
if iteration >= epoch_size:
47+
break
48+
images, targets = batch[0], batch[1]
49+
with torch.no_grad():
50+
if cuda:
51+
images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
52+
targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
53+
else:
54+
images = Variable(torch.from_numpy(images).type(torch.FloatTensor))
55+
targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
8056
optimizer.zero_grad()
81-
outputs = net(images_val)
57+
outputs = net(images)
8258
losses = []
8359
for i in range(3):
84-
loss_item = yolo_losses[i](outputs[i], targets_val)
60+
loss_item = yolo_losses[i](outputs[i], targets)
8561
losses.append(loss_item[0])
8662
loss = sum(losses)
87-
val_loss += loss
63+
loss.backward()
64+
optimizer.step()
65+
66+
total_loss += loss
67+
waste_time = time.time() - start_time
68+
69+
pbar.set_postfix(**{'total_loss': total_loss.item() / (iteration + 1),
70+
'lr' : get_lr(optimizer),
71+
'step/s' : waste_time})
72+
pbar.update(1)
73+
74+
start_time = time.time()
75+
76+
print('Start Validation')
77+
with tqdm(total=epoch_size_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) as pbar:
78+
for iteration, batch in enumerate(genval):
79+
if iteration >= epoch_size_val:
80+
break
81+
images_val, targets_val = batch[0], batch[1]
82+
83+
with torch.no_grad():
84+
if cuda:
85+
images_val = Variable(torch.from_numpy(images_val).type(torch.FloatTensor)).cuda()
86+
targets_val = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets_val]
87+
else:
88+
images_val = Variable(torch.from_numpy(images_val).type(torch.FloatTensor))
89+
targets_val = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets_val]
90+
optimizer.zero_grad()
91+
outputs = net(images_val)
92+
losses = []
93+
for i in range(3):
94+
loss_item = yolo_losses[i](outputs[i], targets_val)
95+
losses.append(loss_item[0])
96+
loss = sum(losses)
97+
val_loss += loss
98+
pbar.set_postfix(**{'total_loss': val_loss.item() / (iteration + 1)})
99+
pbar.update(1)
100+
88101
print('Finish Validation')
89-
print('\nEpoch:'+ str(epoch+1) + '/' + str(Epoch))
102+
print('Epoch:'+ str(epoch+1) + '/' + str(Epoch))
90103
print('Total Loss: %.4f || Val Loss: %.4f ' % (total_loss/(epoch_size+1),val_loss/(epoch_size_val+1)))
91104

92105
print('Saving state, iter:', str(epoch+1))
93106
torch.save(model.state_dict(), 'logs/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.pth'%((epoch+1),total_loss/(epoch_size+1),val_loss/(epoch_size_val+1)))
94107

95-
96-
97108
#----------------------------------------------------#
98109
# 检测精度mAP和pr曲线计算参考视频
99110
# https://www.bilibili.com/video/BV1zE411u7Vw
@@ -209,7 +220,7 @@ def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epo
209220
param.requires_grad = False
210221

211222
for epoch in range(Init_Epoch,Freeze_Epoch):
212-
fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,gen_val,Freeze_Epoch,Cuda)
223+
fit_one_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,gen_val,Freeze_Epoch,Cuda)
213224
lr_scheduler.step()
214225

215226
if True:
@@ -246,5 +257,5 @@ def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epo
246257
param.requires_grad = True
247258

248259
for epoch in range(Freeze_Epoch,Unfreeze_Epoch):
249-
fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,gen_val,Unfreeze_Epoch,Cuda)
260+
fit_one_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,gen_val,Unfreeze_Epoch,Cuda)
250261
lr_scheduler.step()

0 commit comments

Comments
 (0)